]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : vulkan (skip) (llama/0)
authorGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 18:48:22 +0000 (21:48 +0300)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
16 files changed:
CMakeLists.txt
src/CMakeLists.txt
src/ggml-vulkan.cpp
src/vulkan-shaders/acc.comp [new file with mode: 0644]
src/vulkan-shaders/concat.comp
src/vulkan-shaders/mul_mat_vec.comp
src/vulkan-shaders/mul_mat_vec_nc.comp
src/vulkan-shaders/mul_mat_vec_p021.comp
src/vulkan-shaders/mul_mat_vec_q2_k.comp
src/vulkan-shaders/mul_mat_vec_q3_k.comp
src/vulkan-shaders/mul_mat_vec_q4_k.comp
src/vulkan-shaders/mul_mat_vec_q5_k.comp
src/vulkan-shaders/mul_mat_vec_q6_k.comp
src/vulkan-shaders/mul_mm.comp
src/vulkan-shaders/repeat.comp [new file with mode: 0644]
src/vulkan-shaders/vulkan-shaders-gen.cpp

index 357e7e51e91521d1839dcbb5178d0fc394fa113c..cc16858849783fbe276979eebea331a17a8aa737 100644 (file)
@@ -135,6 +135,7 @@ option(GGML_VULKAN                          "ggml: use Vulkan"
 option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
 option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)
 option(GGML_VULKAN_MEMORY_DEBUG             "ggml: enable Vulkan memory debug output"         OFF)
+option(GGML_VULKAN_PERF                     "ggml: enable Vulkan perf output"                 OFF)
 option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"                  OFF)
 option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
 option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF)
index 574172b4efa9a974c2bc6e7cfa0a56902ad933e0..ff84b9bb5f0f2c360df7713e20a11c6c9fcbca96 100644 (file)
@@ -612,6 +612,10 @@ if (GGML_VULKAN)
             add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
         endif()
 
+        if (GGML_VULKAN_PERF)
+            add_compile_definitions(GGML_VULKAN_PERF)
+        endif()
+
         if (GGML_VULKAN_VALIDATE)
             add_compile_definitions(GGML_VULKAN_VALIDATE)
         endif()
index 9e3074d58655bc54cc9a4dce022906dd5a509d02..ca4f44cf75615afdd710c6c077fdd89b758df9e2 100644 (file)
@@ -1,6 +1,6 @@
 #include "ggml-vulkan.h"
 #include <vulkan/vulkan_core.h>
-#ifdef GGML_VULKAN_RUN_TESTS
+#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
 #include <chrono>
 #endif
 
@@ -17,6 +17,7 @@
 #include <memory>
 #include <limits>
 #include <map>
+#include <unordered_map>
 #include <memory>
 #include <mutex>
 
@@ -34,9 +35,7 @@
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
 
-#define VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN 0
-#define VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI 1
-#define VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE 2
+#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32
 
 #define GGML_VK_MAX_NODES 8192
 
@@ -74,6 +73,8 @@ struct vk_queue {
     std::vector<vk::CommandBuffer> cmd_buffers;
 
     vk::PipelineStageFlags stage_flags;
+
+    bool transfer_only;
 };
 
 struct vk_pipeline_struct {
@@ -133,6 +134,9 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
 #ifdef GGML_VULKAN_MEMORY_DEBUG
 class vk_memory_logger;
 #endif
+#ifdef GGML_VULKAN_PERF
+class vk_perf_logger;
+#endif
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 struct vk_device_struct {
@@ -148,7 +152,6 @@ struct vk_device_struct {
     vk_queue compute_queue;
     vk_queue transfer_queue;
     bool single_queue;
-    uint32_t descriptor_set_mode;
     uint32_t subgroup_size;
     bool uma;
 
@@ -177,6 +180,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
     vk_pipeline pipeline_mul_f32;
     vk_pipeline pipeline_div_f32;
@@ -188,6 +192,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
+    vk_pipeline pipeline_repeat_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
@@ -207,7 +212,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
 
-    std::vector<vk_pipeline_ref> pipelines;
+    std::unordered_map<std::string, vk_pipeline_ref> pipelines;
+    std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
 
     std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
 
@@ -219,6 +225,9 @@ struct vk_device_struct {
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
+#ifdef GGML_VULKAN_PERF
+    std::unique_ptr<vk_perf_logger> perf_logger;
+#endif
 
     ~vk_device_struct() {
         VK_LOG_DEBUG("destroy device " << name);
@@ -233,11 +242,11 @@ struct vk_device_struct {
         }
 
         for (auto& pipeline : pipelines) {
-            if (pipeline.expired()) {
+            if (pipeline.second.expired()) {
                 continue;
             }
 
-            vk_pipeline pl = pipeline.lock();
+            vk_pipeline pl = pipeline.second.lock();
             ggml_vk_destroy_pipeline(device, pl);
         }
         pipelines.clear();
@@ -481,6 +490,48 @@ private:
 #define VK_LOG_MEMORY(msg) ((void) 0)
 #endif // GGML_VULKAN_MEMORY_DEBUG
 
+#if defined(GGML_VULKAN_PERF)
+
+class vk_perf_logger {
+public:
+    void print_timings() {
+        std::cerr << "----------------\nVulkan Timings:" << std::endl;
+        for (const auto& t : timings) {
+            uint64_t total = 0;
+            for (const auto& time : t.second) {
+                total += time;
+            }
+            std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl;
+        }
+
+        timings.clear();
+    }
+
+    void log_timing(const ggml_tensor * node, uint64_t time) {
+        if (node->op == GGML_OP_UNARY) {
+            timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
+            return;
+        }
+        if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
+            const uint64_t m = node->src[0]->ne[1];
+            const uint64_t n = node->src[1]->ne[1];
+            const uint64_t k = node->src[1]->ne[0];
+            std::string name = ggml_op_name(node->op);
+            if (n == 1) {
+                name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
+            } else {
+                name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
+            }
+            timings[name].push_back(time);
+            return;
+        }
+        timings[ggml_op_name(node->op)].push_back(time);
+    }
+private:
+    std::map<std::string, std::vector<uint64_t>> timings;
+};
+#endif // GGML_VULKAN_PERF
+
 struct ggml_backend_vk_context {
     std::string name;
 
@@ -491,9 +542,6 @@ struct ggml_backend_vk_context {
     size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
     vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
     vk::Fence fence;
-    vk_buffer staging;
-    size_t staging_size;
-    size_t staging_offset;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -597,35 +645,9 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
     descriptor_set_layout_create_info.setPNext(&dslbfci);
     pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
 
-    // Check if device supports multiple descriptors per pool
-    if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
-        const uint32_t alloc_count = 2;
-
-        // Try allocating multiple sets from one pool
-        // This fails on AMD for some reason, so add a fall back to allocating one pool per set
-        vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
-        vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
-        vk::DescriptorPool pool = device->device.createDescriptorPool(descriptor_pool_create_info);
-
-        std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
-        for (uint32_t i = 0; i < alloc_count; i++) {
-            layouts[i] = pipeline->dsl;
-        }
-        try {
-            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
-            std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
-        } catch(vk::OutOfPoolMemoryError const&) {
-            device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
-        }
-
-        device->device.destroyDescriptorPool(pool);
-    }
-
-    if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
-        vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
-        vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
-        pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
-    }
+    vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
+    vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
+    pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
 
     pipeline->descriptor_set_idx = 0;
 
@@ -659,7 +681,7 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
         pipeline->layout);
     pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
 
-    device->pipelines.push_back(pipeline);
+    device->pipelines.insert({ pipeline->name, pipeline });
 }
 
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -680,34 +702,49 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
     device.destroyPipeline(pipeline->pipeline);
 }
 
-static void ggml_pipeline_allocate_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
-    VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
-    if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
-        // Enough descriptors are available
-        return;
-    }
+static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
+    VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
+    device->pipeline_descriptor_set_requirements[pipeline->name] += n;
+}
 
+static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
     std::lock_guard<std::mutex> guard(device->mutex);
 
-    if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
-        const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
+    for (auto& pair : device->pipeline_descriptor_set_requirements) {
+        vk_pipeline pipeline = device->pipelines.at(pair.first).lock();
+        const uint64_t n = pair.second;
+
+        VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
 
-        std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
-        for (uint32_t i = 0; i < alloc_count; i++) {
-            layouts[i] = pipeline->dsl;
+        if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
+            // Enough descriptors are available
+            continue;
         }
-        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
-        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
-        pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
-    } else {
-        for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
-            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
-            vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
-            pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
 
-            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
+        uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
+        uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
+        uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
+
+        while (to_alloc > 0) {
+            const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
+            to_alloc -= alloc_count;
+            pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
+
+            if (pool_idx >= pipeline->descriptor_pools.size()) {
+                vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
+                vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
+                pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
+            }
+
+            std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
+            for (uint32_t i = 0; i < alloc_count; i++) {
+                layouts[i] = pipeline->dsl;
+            }
+            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data());
             std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
-            pipeline->descriptor_sets.push_back(sets[0]);
+            pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
+
+            pool_idx++;
         }
     }
 }
@@ -868,11 +905,12 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
     abort();
 }
 
-static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags) {
+static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
     VK_LOG_DEBUG("ggml_vk_create_queue()");
     std::lock_guard<std::mutex> guard(device->mutex);
 
     q.queue_family_index = queue_family_index;
+    q.transfer_only = transfer_only;
 
     vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
     q.pool = device->device.createCommandPool(command_pool_create_info_compute);
@@ -1069,13 +1107,16 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
 
 static void ggml_vk_sync_buffers(vk_context& ctx) {
     VK_LOG_DEBUG("ggml_vk_sync_buffers()");
+
+    const bool transfer_queue = ctx->q->transfer_only;
+
     ctx->s->buffer.pipelineBarrier(
         ctx->q->stage_flags,
         ctx->q->stage_flags,
         {},
         { {
-          {vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite},
-          {vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite}
+          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
+          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
         } },
         {},
         {}
@@ -1649,6 +1690,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -1668,6 +1711,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -1707,6 +1752,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #ifdef GGML_VULKAN_MEMORY_DEBUG
         device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
 #endif
+#ifdef GGML_VULKAN_PERF
+        device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
+#endif
 
         size_t dev_num = vk_instance.device_indices[idx];
 
@@ -1837,17 +1885,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_create_info.setPNext(&device_features2);
         device->device = device->physical_device.createDevice(device_create_info);
 
-        device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
-
         // Queues
-        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
+        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
 
         // Shaders
         ggml_vk_load_shaders(device);
 
         if (!device->single_queue) {
             const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
-            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
+            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
         } else {
             // TODO: Use pointer or reference to avoid copy
             device->transfer_queue = device->compute_queue;
@@ -2134,9 +2180,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
 
     ctx->fence = ctx->device->device.createFence({});
 
-    ctx->staging_size = 0;
-    ctx->staging_offset = 0;
-
 #ifdef GGML_VULKAN_CHECK_RESULTS
     const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
     vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
@@ -2569,23 +2612,15 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
         return;
     }
 
-    // Staging buffer required
-    vk_buffer staging = ctx->staging;
-    size_t staging_offset = ctx->staging_offset;
-    const size_t copy_size = ts*ne/bs;
-    if (ctx->staging->size < ctx->staging_offset + copy_size) {
-        if (sync_staging) {
-            // Create temporary larger buffer
-            ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
-
-            staging = ctx->device->sync_staging;
-            staging_offset = 0;
-        } else {
-            GGML_ABORT("fatal error");
-        }
+    if (!sync_staging) {
+        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
     }
 
-    VkBufferCopy buf_copy{ staging_offset, offset, copy_size };
+    // Staging buffer required
+    vk_buffer& staging = ctx->device->sync_staging;
+    const uint64_t copy_size = ts*ne/bs;
+    ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
+    VkBufferCopy buf_copy{ 0, offset, copy_size };
 
     ggml_vk_sync_buffers(subctx);
     vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy);
@@ -2594,14 +2629,14 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
         for (uint64_t i2 = 0; i2 < ne2; i2++) {
             // Find longest contiguous slice
             if (ne1*nb1 == dstnb2) {
-                deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
+                deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
             } else {
                 for (uint64_t i1 = 0; i1 < ne1; i1++) {
                     if (ne0*nb0/bs == dstnb1) {
-                        deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
+                        deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
                     } else {
                         const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
-                        const uint64_t d_off = staging_offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
+                        const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
                         for (uint64_t i0 = 0; i0 < ne0; i0++) {
                             deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
                         }
@@ -2612,7 +2647,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     }
 }
 
-static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
     // Buffer is already mapped
     if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -2647,21 +2682,18 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
     }
     VK_LOG_DEBUG("STAGING");
 
+    if (!sync_staging) {
+        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
+    }
+
     // Staging buffer required
     const size_t copy_size = width*height;
-    if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) {
-        if (sync_staging) {
-            ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
+    ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
 
-            staging_buffer = dst->device->sync_staging;
-            staging_offset = 0;
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    }
+    vk_buffer& staging_buffer = dst->device->sync_staging;
 
     VkBufferCopy buf_copy = {
-        staging_offset,
+        0,
         offset,
         copy_size};
 
@@ -2669,17 +2701,17 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
     vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy);
 
     if (width == spitch) {
-        deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset, src, width * height, &subctx->in_memcpys);
+        deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
     } else {
         for (size_t i = 0; i < height; i++) {
-            deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
+            deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
         }
     }
 }
 
-static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
-    return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging);
+    return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
 }
 
 static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
@@ -2694,7 +2726,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
     } else {
         vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
         ggml_vk_ctx_begin(dst->device, subctx);
-        ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true);
+        ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
         ggml_vk_ctx_end(subctx);
 
         for (auto& cpy : subctx->in_memcpys) {
@@ -2712,7 +2744,7 @@ static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src
     ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
 }
 
-static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
     GGML_ASSERT(width > 0);
     GGML_ASSERT(height > 0);
@@ -2749,18 +2781,15 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
     }
     VK_LOG_DEBUG("STAGING");
 
+    if (!sync_staging) {
+        GGML_ABORT("Asynchronous read from non-pinned memory not supported");
+    }
+
     // Fall back to staging buffer
     const size_t copy_size = dpitch * height;
-    if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) {
-        if (sync_staging) {
-            // Create temporary larger buffer
-            ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
+    ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
 
-            staging_buffer = src->device->sync_staging;
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    }
+    vk_buffer& staging_buffer = src->device->sync_staging;
 
     ggml_vk_sync_buffers(subctx);
     subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
@@ -2768,8 +2797,8 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
     deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
 }
 
-static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
-    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging);
+static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
+    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
 }
 
 static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
@@ -2781,7 +2810,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
     } else {
         vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
         ggml_vk_ctx_begin(src->device, subctx);
-        ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true);
+        ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
         ggml_vk_ctx_end(subctx);
 
         ggml_vk_submit(subctx, src->device->fence);
@@ -2982,10 +3011,11 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
 }
 
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
     GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);  // NOLINT
     GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
 
@@ -3059,6 +3089,56 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
+
+    if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
+    } else {
+        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
+    }
+    if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
+    } else {
+        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+    }
+    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
+    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
+
+    if (dryrun) {
+        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
+        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
+        if (
+                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
+                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
+                (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) {
+            GGML_ABORT("Requested preallocation size is too large");
+        }
+        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
+            ctx->prealloc_size_x = x_sz_upd;
+        }
+        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+            ctx->prealloc_size_y = y_sz_upd;
+        }
+        if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
+            ctx->prealloc_size_split_k = split_k_size;
+        }
+
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        if (qx_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+        }
+        if (qy_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+        }
+        if (split_k > 1) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
+        }
+        return;
+    }
+
     vk_buffer d_D = extra->buffer_gpu.lock();
     const uint64_t d_buf_offset = extra->offset + dst->view_offs;
     GGML_ASSERT(d_D != nullptr);
@@ -3094,34 +3174,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         GGML_ASSERT(qy_sz == y_sz);
     }
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-
-    // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
-    if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-    }
-    if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
-    }
-    if (split_k > 1) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
-    }
-
     if (x_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     } else if (qx_needs_dequant) {
@@ -3155,10 +3207,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     );  // NOLINT
 }
 
-static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
     GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);  // NOLINT
     GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
 
@@ -3222,6 +3275,47 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
+    if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+    }
+    if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+    } else {
+        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+    }
+    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
+    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
+    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
+    GGML_ASSERT(dmmv != nullptr);
+
+    if (dryrun) {
+        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
+        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (
+                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
+                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
+            GGML_ABORT("Requested preallocation size is too large");
+        }
+        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
+            ctx->prealloc_size_x = x_sz_upd;
+        }
+        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+            ctx->prealloc_size_y = y_sz_upd;
+        }
+
+        // Request descriptor sets
+        if (qx_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+        }
+        if (qy_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+        }
+        ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
+        return;
+    }
+
     vk_buffer d_D = extra->buffer_gpu.lock();
     const uint64_t d_buf_offset = extra->offset + dst->view_offs;
     GGML_ASSERT(d_D != nullptr);
@@ -3254,30 +3348,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         GGML_ASSERT(qy_sz == y_sz);
     }
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-    GGML_ASSERT(dmmv != nullptr);
-
-    // Allocate descriptor sets
-    if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-    }
-    if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
-    }
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13);
-
     if (x_non_contig) {
         GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -3320,10 +3390,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
                               sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 }
 
-static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);  // NOLINT
     GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);  // NOLINT
@@ -3364,6 +3435,12 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+        return;
+    }
+
     vk_buffer d_D = extra->buffer_gpu.lock();
     const uint64_t d_buf_offset = extra->offset + dst->view_offs;
     GGML_ASSERT(d_D != nullptr);
@@ -3376,9 +3453,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         GGML_ASSERT(d_Qx != nullptr);
     }
 
-    // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
-
     const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
 
@@ -3391,10 +3465,11 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
-static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
@@ -3439,6 +3514,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t qy_sz = ggml_nbytes(src1);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+        return;
+    }
+
     vk_buffer d_D = extra->buffer_gpu.lock();
     const uint64_t d_buf_offset = extra->offset + dst->view_offs;
     GGML_ASSERT(d_D != nullptr);
@@ -3451,9 +3532,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
         GGML_ASSERT(d_Qx != nullptr);
     }
 
-    // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
-
     const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
 
@@ -3467,20 +3545,20 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
         { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
     if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
-        ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst);
+        ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
     } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
-        ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst);
+        ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
     } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
-        ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst);
+        ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
-        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst);
+        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     }
 }
 
-static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
@@ -3570,6 +3648,48 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t ids_sz = nbi2;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
+
+    if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
+    } else {
+        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
+    }
+    if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
+    } else {
+        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+    }
+    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
+    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
+
+    if (dryrun) {
+        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
+        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (
+                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
+                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
+            GGML_ABORT("Requested preallocation size is too large");
+        }
+        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
+            ctx->prealloc_size_x = x_sz_upd;
+        }
+        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+            ctx->prealloc_size_y = y_sz_upd;
+        }
+
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        if (qx_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+        }
+        if (qy_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+        }
+        return;
+    }
+
     vk_buffer d_D = extra->buffer_gpu.lock();
     const uint64_t d_buf_offset = extra->offset + dst->view_offs;
     GGML_ASSERT(d_D != nullptr);
@@ -3609,31 +3729,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         GGML_ASSERT(qy_sz == y_sz);
     }
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-
-    // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
-    if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-    }
-    if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
-    }
-
     if (x_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     } else if (qx_needs_dequant) {
@@ -3668,11 +3763,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     );  // NOLINT
 }
 
-static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
     GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);  // NOLINT
     GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -3746,6 +3842,47 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const uint64_t ids_sz = nbi2;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
+    if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+    }
+    if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+    } else {
+        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+    }
+    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
+    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
+    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
+    GGML_ASSERT(dmmv != nullptr);
+
+    if (dryrun) {
+        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
+        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (
+                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
+                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
+            GGML_ABORT("Requested preallocation size is too large");
+        }
+        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
+            ctx->prealloc_size_x = x_sz_upd;
+        }
+        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+            ctx->prealloc_size_y = y_sz_upd;
+        }
+
+        // Request descriptor sets
+        if (qx_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+        }
+        if (qy_needs_dequant) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+        }
+        ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
+        return;
+    }
+
     vk_buffer d_D = extra->buffer_gpu.lock();
     const uint64_t d_buf_offset = extra->offset + dst->view_offs;
     GGML_ASSERT(d_D != nullptr);
@@ -3783,30 +3920,6 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         GGML_ASSERT(qy_sz == y_sz);
     }
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-    GGML_ASSERT(dmmv != nullptr);
-
-    // Allocate descriptor sets
-    if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-    }
-    if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
-    }
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13);
-
     if (x_non_contig) {
         GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -3845,85 +3958,15 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
 }
 
-static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
     if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
-        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst);
+        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     } else {
-        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
-    }
-}
-
-static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    VK_LOG_DEBUG("ggml_vk_op_repeat(" << src0 << ", " << src1 << ", " << dst << ")");
-    const uint64_t ne0 = dst->ne[0];
-    const uint64_t ne1 = dst->ne[1];
-    const uint64_t ne2 = dst->ne[2];
-    const uint64_t ne3 = dst->ne[3];
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t nb0 = dst->nb[0];
-    const uint64_t nb1 = dst->nb[1];
-    const uint64_t nb2 = dst->nb[2];
-    const uint64_t nb3 = dst->nb[3];
-
-    const uint64_t nb00 = src0->nb[0];
-    const uint64_t nb01 = src0->nb[1];
-    const uint64_t nb02 = src0->nb[2];
-    const uint64_t nb03 = src0->nb[3];
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const uint64_t nr0 = ne0/ne00;
-    const uint64_t nr1 = ne1/ne01;
-    const uint64_t nr2 = ne2/ne02;
-    const uint64_t nr3 = ne3/ne03;
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
-    ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
-
-    const vk_buffer src_buf = extra_src0->buffer_gpu.lock();
-    const uint64_t src_offset = extra_src0->offset + src0->view_offs;
-    vk_buffer dst_buf = extra->buffer_gpu.lock();
-    const uint64_t dst_offset = extra->offset + dst->view_offs;
-
-    std::vector<vk::BufferCopy> copies;
-
-    for                         (uint64_t i3 = 0; i3 < nr3;  i3++) {
-        for                     (uint64_t k3 = 0; k3 < ne03; k3++) {
-            for                 (uint64_t i2 = 0; i2 < nr2;  i2++) {
-                for             (uint64_t k2 = 0; k2 < ne02; k2++) {
-                    for         (uint64_t i1 = 0; i1 < nr1;  i1++) {
-                        for     (uint64_t k1 = 0; k1 < ne01; k1++) {
-                            for (uint64_t i0 = 0; i0 < nr0;  i0++) {
-                                copies.push_back({
-                                    src_offset + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01,
-                                    dst_offset + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
-                                    ne00*nb0,
-                                });
-                            }
-                        }
-                    }
-                }
-            }
-        }
+        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     }
-
-    ggml_vk_sync_buffers(subctx);
-    subctx->s->buffer.copyBuffer(src_buf->buffer, dst_buf->buffer, copies);
-
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(src1);
 }
 
-
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
@@ -3935,6 +3978,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_get_rows_f32[src0->type];
         }
         return nullptr;
+    case GGML_OP_ACC:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_acc_f32;
+        }
+        return nullptr;
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_add_f32;
@@ -3999,6 +4047,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pad_f32;
         }
         return nullptr;
+    case GGML_OP_REPEAT:
+        if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
+            return ctx->device->pipeline_repeat_f32;
+        }
+        return nullptr;
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
@@ -4121,15 +4174,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     GGML_UNUSED(src2);
 }
 
-static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
-    switch(op) {
-    case GGML_OP_REPEAT:
-        return ggml_vk_op_repeat;
-    default:
-        return nullptr;
-    }
-}
-
 static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     switch (op) {
     case GGML_OP_CPY:
@@ -4145,6 +4189,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
+    case GGML_OP_REPEAT:
         return true;
     default:
         return false;
@@ -4152,7 +4197,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
 }
 
 template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     if (src1 != nullptr) {
         std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4160,7 +4205,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     if (src2 != nullptr) {
         std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
     }
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")");
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
     GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT
     GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0));  // NOLINT
     GGML_ASSERT(dst->extra != nullptr);
@@ -4192,20 +4238,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint64_t ned = ned0 * ned1;
 
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
-    ggml_vk_func_t op_func;
 
     if (pipeline == nullptr) {
-        op_func = ggml_vk_op_get_func(op);
-        if (op_func == nullptr) {
-            std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
-            if (src1 != nullptr) {
-                std::cerr << " and " << ggml_type_name(src1->type);
-            }
-            std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
-            GGML_ABORT("fatal error");
+        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
+        if (src1 != nullptr) {
+            std::cerr << " and " << ggml_type_name(src1->type);
         }
+        std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
+        GGML_ABORT("fatal error");
+    }
 
-        op_func(ctx, subctx, src0, src1, dst);
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
         return;
     }
 
@@ -4294,190 +4338,143 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     std::array<uint32_t, 3> elements;
 
     // Single call if dimension 2 is contiguous
-    if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
-
-        switch (op) {
-        case GGML_OP_NORM:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_SUM_ROWS:
-            {
-                const uint32_t nr = ggml_nrows(src0);
-                if (nr > 262144) {
-                    elements = { 512, 512, CEIL_DIV(nr, 262144) };
-                } else if (nr > 512) {
-                    elements = { 512, CEIL_DIV(nr, 512), 1 };
-                } else {
-                    elements = { nr, 1, 1 };
-                }
-            } break;
-        case GGML_OP_GROUP_NORM:
-            {
-                const uint32_t num_groups = dst->op_params[0];
-                elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
-            } break;
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_ROPE:
-            elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
-            break;
-        case GGML_OP_GET_ROWS:
-            elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
-            break;
-        case GGML_OP_ARGSORT:
-            elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
-            break;
-        case GGML_OP_IM2COL:
-            {
-                const bool is_2D = dst->op_params[6] == 1;
-
-                const uint32_t IC = src1->ne[is_2D ? 2 : 1];
+    GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
 
-                const uint32_t KH = is_2D ? src0->ne[1] : 1;
-                const uint32_t KW =         src0->ne[0];
+    switch (op) {
+    case GGML_OP_NORM:
+    case GGML_OP_RMS_NORM:
+    case GGML_OP_SOFT_MAX:
+    case GGML_OP_SUM_ROWS:
+        {
+            const uint32_t nr = ggml_nrows(src0);
+            if (nr > 262144) {
+                elements = { 512, 512, CEIL_DIV(nr, 262144) };
+            } else if (nr > 512) {
+                elements = { 512, CEIL_DIV(nr, 512), 1 };
+            } else {
+                elements = { nr, 1, 1 };
+            }
+        } break;
+    case GGML_OP_GROUP_NORM:
+        {
+            const uint32_t num_groups = dst->op_params[0];
+            elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
+        } break;
+    case GGML_OP_DIAG_MASK_INF:
+    case GGML_OP_ROPE:
+        elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
+        break;
+    case GGML_OP_GET_ROWS:
+        elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
+        break;
+    case GGML_OP_ARGSORT:
+        elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+        break;
+    case GGML_OP_IM2COL:
+        {
+            const bool is_2D = dst->op_params[6] == 1;
 
-                const uint32_t OH = is_2D ? dst->ne[2] : 1;
-                const uint32_t OW =         dst->ne[1];
+            const uint32_t IC = src1->ne[is_2D ? 2 : 1];
 
-                const uint32_t batch = src1->ne[3];
+            const uint32_t KH = is_2D ? src0->ne[1] : 1;
+            const uint32_t KW =         src0->ne[0];
 
-                elements = { OW * KW * KH, OH, batch * IC };
-            } break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                const uint32_t dim = dst->op_params[0];
-                uint32_t half_ceil = (dim + 1) / 2;
-                elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
-            } break;
-        case GGML_OP_ADD:
-        case GGML_OP_DIV:
-        case GGML_OP_MUL:
-        case GGML_OP_SCALE:
-        case GGML_OP_SQR:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_CLAMP:
-        case GGML_OP_PAD:
-        case GGML_OP_CPY:
-        case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_UNARY:
-            {
-                const uint32_t ne = ggml_nelements(dst);
-                if (ne > 262144) {
-                    elements = { 512, 512, CEIL_DIV(ne, 262144) };
-                } else if (ne > 512) {
-                    elements = { 512, CEIL_DIV(ne, 512), 1 };
-                } else {
-                    elements = { ne, 1, 1 };
-                }
-            } break;
-        default:
-            elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
-            break;
-        }
+            const uint32_t OH = is_2D ? dst->ne[2] : 1;
+            const uint32_t OW =         dst->ne[1];
 
-        if (!op_supports_incontiguous) {
-            if (x_sz != VK_WHOLE_SIZE) {
-                x_sz *= ne02 * ne03;
-            }
-            if (use_src1 && y_sz != VK_WHOLE_SIZE) {
-                y_sz *= ne12 * ne13;
-            }
-            if (use_src2 && z_sz != VK_WHOLE_SIZE) {
-                z_sz *= ne22 * ne23;
-            }
-            if (d_sz != VK_WHOLE_SIZE) {
-                d_sz *= ned2 * ned3;
-            }
-        }
+            const uint32_t batch = src1->ne[3];
 
-        if (op == GGML_OP_SOFT_MAX) {
-            // Empty src1 is possible in soft_max, but the shader needs a buffer
-            vk_subbuffer subbuf_y;
-            if (use_src1) {
-                subbuf_y = { d_Y, y_buf_offset, y_sz };
+            elements = { OW * KW * KH, OH, batch * IC };
+        } break;
+    case GGML_OP_TIMESTEP_EMBEDDING:
+        {
+            const uint32_t dim = dst->op_params[0];
+            uint32_t half_ceil = (dim + 1) / 2;
+            elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
+        } break;
+    case GGML_OP_ADD:
+    case GGML_OP_DIV:
+    case GGML_OP_MUL:
+    case GGML_OP_SCALE:
+    case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
+    case GGML_OP_CLAMP:
+    case GGML_OP_PAD:
+    case GGML_OP_REPEAT:
+    case GGML_OP_CPY:
+    case GGML_OP_CONCAT:
+    case GGML_OP_UPSCALE:
+    case GGML_OP_UNARY:
+        {
+            const uint32_t ne = ggml_nelements(dst);
+            if (ne > 262144) {
+                elements = { 512, 512, CEIL_DIV(ne, 262144) };
+            } else if (ne > 512) {
+                elements = { 512, CEIL_DIV(ne, 512), 1 };
             } else {
-                subbuf_y = { d_X, 0, x_sz };
+                elements = { ne, 1, 1 };
             }
+        } break;
+    default:
+        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
+        break;
+    }
 
-            ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-        } else if (op == GGML_OP_ROPE) {
-            // Empty src2 is possible in rope, but the shader needs a buffer
-            vk_subbuffer subbuf_z;
-            if (use_src2) {
-                subbuf_z = { d_Z, z_buf_offset, z_sz };
-            } else {
-                subbuf_z = { d_X, 0, x_sz };
-            }
+    if (!op_supports_incontiguous) {
+        if (x_sz != VK_WHOLE_SIZE) {
+            x_sz *= ne02 * ne03;
+        }
+        if (use_src1 && y_sz != VK_WHOLE_SIZE) {
+            y_sz *= ne12 * ne13;
+        }
+        if (use_src2 && z_sz != VK_WHOLE_SIZE) {
+            z_sz *= ne22 * ne23;
+        }
+        if (d_sz != VK_WHOLE_SIZE) {
+            d_sz *= ned2 * ned3;
+        }
+    }
 
-            ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-        } else if (op == GGML_OP_IM2COL) {
-            // im2col uses only src1 and dst buffers
-            ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-        } else if (use_src2) {
-            ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-        } else if (use_src1) {
-            ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    if (op == GGML_OP_SOFT_MAX) {
+        // Empty src1 is possible in soft_max, but the shader needs a buffer
+        vk_subbuffer subbuf_y;
+        if (use_src1) {
+            subbuf_y = { d_Y, y_buf_offset, y_sz };
         } else {
-            ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+            subbuf_y = { d_X, 0, x_sz };
         }
-    } else {
-        GGML_ASSERT(op != GGML_OP_SOFT_MAX);
-        GGML_ASSERT(op != GGML_OP_ARGSORT);
-        GGML_ASSERT(!use_src2);
-
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03);
 
-        switch (op) {
-        case GGML_OP_NORM:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
-            elements = { (uint32_t)ne01, 1, 1 };
-            break;
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_ROPE:
-            elements = { (uint32_t)ne01, (uint32_t)ne00, 1 };
-            break;
-        case GGML_OP_GET_ROWS:
-            elements = {  (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
-            break;
-        default:
-            elements = { (uint32_t)ne0, 1, 1 };
-            break;
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    } else if (op == GGML_OP_ROPE) {
+        // Empty src2 is possible in rope, but the shader needs a buffer
+        vk_subbuffer subbuf_z;
+        if (use_src2) {
+            subbuf_z = { d_Z, z_buf_offset, z_sz };
+        } else {
+            subbuf_z = { d_X, 0, x_sz };
         }
 
-        for (uint64_t i03 = 0; i03 < ne03; i03++) {
-            for (uint64_t i02 = 0; i02 < ne02; i02++) {
-                const uint32_t it_idx0 = (i03 * ne02 + i02);
-                const uint32_t it_idx1 = use_src1 ? ((i03 % ne13) * ne12 + (i02 % ne12)) : 0;
-                const uint32_t x_offset = x_sz * it_idx0;
-                const uint32_t y_offset = y_sz * it_idx1;
-                const uint32_t d_offset = d_sz * it_idx0;
-
-                if (use_src1) {
-                    ggml_vk_sync_buffers(subctx);
-                    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset + x_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset + y_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
-                } else {
-                    ggml_vk_sync_buffers(subctx);
-                    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset + x_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
-                }
-            }
-        }
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    } else if (op == GGML_OP_IM2COL) {
+        // im2col uses only src1 and dst buffers
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    } else if (use_src2) {
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    } else if (use_src1) {
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    } else {
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
     }
 }
 
-static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {});
-}
-
-static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4489,10 +4486,32 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f, 0,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+    const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] /  dst_type_size,
+        d_offset,
+        0.0f, 0.0f, offset,
+    }, dryrun);
+}
+
+static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4504,10 +4523,10 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f, 0,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4519,10 +4538,10 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f, 0,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4534,10 +4553,10 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f, 0,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
     const uint32_t src0_type_size = ggml_type_size(src0->type);
@@ -4551,10 +4570,10 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f, op_params[0],
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
 
     const float sf0 = (float)dst->ne[0] / src0->ne[0];
@@ -4567,10 +4586,10 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
         (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
         sf0, sf1, sf2, sf3,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4581,10 +4600,10 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         op_params[0], 0.0f
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
@@ -4594,7 +4613,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f,
-    });
+    }, dryrun);
 }
 
 static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -4623,7 +4642,7 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
     });
 }
 
-static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4634,10 +4653,10 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         op_params[0], op_params[1],
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
@@ -4647,10 +4666,23 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         0.0f, 0.0f,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    }, dryrun);
+}
+
+static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4662,40 +4694,41 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         d_offset,
         0.0f, 0.0f,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
-static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    int * op_params = (int *)dst->op_params;
+static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    const int * int_op_params = (const int *)dst->op_params;
+    const float * float_op_params = (const float *)dst->op_params;
 
-    uint32_t num_groups = op_params[0];
-    uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    static const float eps = 1e-6f;
+    const uint32_t num_groups = int_op_params[0];
+    const float eps = float_op_params[1];
+    const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
-static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
-static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     int32_t * op_params = (int32_t *)dst->op_params;
-    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
+    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
 }
 
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
     float scale = op_params[0];
@@ -4717,10 +4750,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         scale, max_bias,
         m0, m1,
         n_head_log2,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     // const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
@@ -4741,10 +4774,10 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     int32_t * op_params = (int32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
@@ -4760,14 +4793,14 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
         ncols,
         ncols_pad,
         op_params[0],
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f });
+static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }
 
-static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const int32_t s0 = dst->op_params[0];
     const int32_t s1 = dst->op_params[1];
     const int32_t p0 = dst->op_params[2];
@@ -4798,22 +4831,22 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
         pelements,
         IC * KH * KW,
         s0, s1, p0, p1, d0, d1,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t dim = dst->op_params[0];
     const uint32_t max_period = dst->op_params[1];
     const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
 
     ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
         nb1, dim, max_period,
-    });
+    }, dryrun);
 }
 
-static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
 }
 
 #ifdef GGML_VULKAN_RUN_TESTS
@@ -4959,9 +4992,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, p, num_it);
+    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
     if (split_k > 1) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
         if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
             // Resize buffer
@@ -5208,7 +5241,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
     ggml_vk_quantize_data(x, qx, ne, quant);
     ggml_vk_dequantize_data(qx, x_ref, ne, quant);
 
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, p, 1);
+    ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
 
@@ -5329,9 +5362,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         y[i] = (i % k == i / k) ? 1.0f : 0.0f;
     }
 
-    ggml_pipeline_allocate_descriptor_sets(ctx->device, p, num_it);
+    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
     if (split_k > 1) {
-        ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
         if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
             // Resize buffer
@@ -5459,137 +5492,6 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
     return extra;
 }
 
-static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
-    VK_LOG_DEBUG("ggml_vk_preallocate_buffers_graph(" << node << ")");
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
-
-    if (extra == nullptr) {
-        return;
-    }
-
-    ggml_tensor * src0 = node->src[0];
-    ggml_tensor * src1 = node->src[1];
-
-    const bool use_src0 = src0 != nullptr;
-    const int64_t ne00 = use_src0 ? src0->ne[0] : 0;
-    const int64_t ne01 = use_src0 ? src0->ne[1] : 0;
-    const int64_t ne02 = use_src0 ? src0->ne[2] : 0;
-    const int64_t ne03 = use_src0 ? src0->ne[3] : 0;
-    const bool use_src1 = src1 != nullptr && node->op != GGML_OP_CPY && node->op != GGML_OP_CONT && node->op != GGML_OP_DUP;
-    const int64_t ne10 = use_src1 ? src1->ne[0] : 0;
-    const int64_t ne11 = use_src1 ? src1->ne[1] : 0;
-    const int64_t ne12 = use_src1 ? src1->ne[2] : 0;
-    const int64_t ne13 = use_src1 ? src1->ne[3] : 0;
-    const int64_t ne20 = node->ne[0];
-    const int64_t ne21 = node->ne[1];
-    const int64_t ne22 = node->ne[2];
-    const int64_t ne23 = node->ne[3];
-
-    const ggml_type src0_type = (use_src0 && src0->type == GGML_TYPE_F32) ? src0->type : GGML_TYPE_F16;
-    const ggml_type src1_type = (use_src1 && src1->type == GGML_TYPE_F32) ? src1->type : GGML_TYPE_F16;
-
-    const bool x_non_contig = use_src0 && !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = use_src1 && !ggml_vk_dim01_contiguous(src1);
-
-    const bool y_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32 && !y_non_contig;
-
-    bool mmp = (use_src0 && use_src1 && (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID)) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false;
-
-    const bool qx_needs_dequant = use_src0 && (!mmp || x_non_contig);
-    const bool qy_needs_dequant = use_src1 && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
-
-    int split_k;
-    if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
-        split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-    } else {
-        split_k = 1;
-    }
-    const uint32_t x_ne = ne00 * ne01;
-    const uint32_t y_ne = ne10 * ne11;
-    const uint32_t d_ne = ne20 * ne21;
-
-    const uint64_t x_sz = (use_src0 && qx_needs_dequant) ? ggml_vk_align_size(sizeof(src0_type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
-    const uint64_t y_sz = (use_src1 && qy_needs_dequant) ? ggml_vk_align_size(sizeof(src1_type) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
-    uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
-    const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0;
-
-    if (extra->buffer_gpu.expired()) {
-        // Workaround for CPU backend BLAS matmul calls
-        extra->buffer_gpu = ggml_vk_create_buffer_temp(ctx, d_sz);
-    }
-
-    switch (node->op) {
-    case GGML_OP_REPEAT:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_ADD:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_CPY:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_ROPE:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_IM2COL:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_LEAKY_RELU:
-        break;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(node)) {
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_TANH:
-            break;
-        default:
-            return;
-        }
-        break;
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-        if (
-                x_sz > ctx->device->max_memory_allocation_size ||
-                y_sz > ctx->device->max_memory_allocation_size ||
-                d_sz > ctx->device->max_memory_allocation_size ||
-                split_k_size > ctx->device->max_memory_allocation_size) {
-            GGML_ABORT("Requested preallocation size is too large");
-        }
-        if (ctx->prealloc_size_x < x_sz) {
-            ctx->prealloc_size_x = x_sz;
-        }
-        if (ctx->prealloc_size_y < y_sz) {
-            ctx->prealloc_size_y = y_sz;
-        }
-        if (ctx->prealloc_size_split_k < split_k_size) {
-            ctx->prealloc_size_split_k = split_k_size;
-        }
-        if (ctx->staging_size < x_sz + y_sz) {
-            ctx->staging_size = x_sz + y_sz;
-        }
-        break;
-    default:
-        return;
-    }
-}
-
 static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
 #if defined(GGML_VULKAN_RUN_TESTS)
     ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul,
@@ -5754,19 +5656,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
-    if (ctx->staging == nullptr || (ctx->staging_size > 0 && ctx->staging->size < ctx->staging_size)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(staging_size: " << ctx->staging_size << ")");
-        // Resize buffer
-        if (ctx->staging != nullptr) {
-            ggml_vk_destroy_buffer(ctx->staging);
-        }
-        ctx->staging = ggml_vk_create_buffer_check(ctx->device, ctx->staging_size,
-            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
-            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
-    }
 }
 
-static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node){
+static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node, bool dryrun){
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
 
     if (ggml_is_empty(node) || extra == nullptr) {
@@ -5775,7 +5667,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
     VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
     ctx->semaphore_idx = 0;
-    ctx->staging_offset = 0;
 
     const ggml_tensor * src0 = node->src[0];
     const ggml_tensor * src1 = node->src[1];
@@ -5804,6 +5695,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_REPEAT:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_ACC:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -5839,49 +5731,55 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
     vk_context compute_ctx;
 
-    if (ctx->compute_ctx.expired()) {
-        compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-        ctx->compute_ctx = compute_ctx;
-        ggml_vk_ctx_begin(ctx->device, compute_ctx);
-    } else {
-        compute_ctx = ctx->compute_ctx.lock();
+    if (!dryrun) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
     }
 
     switch (node->op) {
     case GGML_OP_REPEAT:
-        ggml_vk_repeat(ctx, compute_ctx, src0, node);
+        ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
+
+        break;
+    case GGML_OP_ACC:
+        ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_GET_ROWS:
-        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_ADD:
-        ggml_vk_add(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_MUL:
-        ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_DIV:
-        ggml_vk_div(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_CONCAT:
-        ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_UPSCALE:
-        ggml_vk_upscale(ctx, compute_ctx, src0, node);
+        ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_SCALE:
-        ggml_vk_scale(ctx, compute_ctx, src0, node);
+        ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_SQR:
-        ggml_vk_sqr(ctx, compute_ctx, src0, node);
+        ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_SIN:
@@ -5893,29 +5791,29 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_CLAMP:
-        ggml_vk_clamp(ctx, compute_ctx, src0, node);
+        ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_PAD:
-        ggml_vk_pad(ctx, compute_ctx, src0, node);
+        ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
-        ggml_vk_cpy(ctx, compute_ctx, src0, node);
+        ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_NORM:
-        ggml_vk_norm(ctx, compute_ctx, src0, node);
+        ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_GROUP_NORM:
-        ggml_vk_group_norm(ctx, compute_ctx, src0, node);
+        ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, compute_ctx, src0, node);
+        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_UNARY:
@@ -5925,59 +5823,63 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
-            ggml_vk_unary(ctx, compute_ctx, src0, node);
+            ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
             break;
         default:
             return;
         }
         break;
     case GGML_OP_DIAG_MASK_INF:
-        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
+        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     case GGML_OP_ARGSORT:
-        ggml_vk_argsort(ctx, compute_ctx, src0, node);
+        ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_SUM_ROWS:
-        ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
+        ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_IM2COL:
-        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_TIMESTEP_EMBEDDING:
-        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
+        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_LEAKY_RELU:
-        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
+        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
 
         break;
     case GGML_OP_MUL_MAT:
-        ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node);
+        ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_MUL_MAT_ID:
-        ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node);
+        ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     default:
         return;
     }
 
+    if (dryrun) {
+        return;
+    }
+
     ctx->tensor_ctxs[node_idx] = compute_ctx;
 
-#ifdef GGML_VULKAN_CHECK_RESULTS
+#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF)
     // Force context reset on each node so that each tensor ends up in its own context
     // and can be run and compared to its CPU equivalent separately
     last_node = true;
@@ -5995,6 +5897,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
 
     switch (tensor->op) {
     case GGML_OP_ADD:
+    case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
@@ -6063,6 +5966,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
 
     vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
 
+#ifdef GGML_VULKAN_PERF
+    std::chrono::steady_clock::time_point start;
+#endif // GGML_VULKAN_PERF
+
     // Only run if ctx hasn't been submitted yet
     if (!subctx->seqs.empty()) {
         // Do staging buffer copies
@@ -6070,11 +5977,21 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
+#ifdef GGML_VULKAN_PERF
+        start = std::chrono::steady_clock::now();
+#endif // GGML_VULKAN_PERF
+
         ggml_vk_submit(subctx, ctx->fence);
     }
 
     if (tensor_idx == subctx->exit_tensor_idx) {
         VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
+
+#ifdef GGML_VULKAN_PERF
+        auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start);
+        ctx->device->perf_logger->log_timing(tensor, duration.count());
+#endif // GGML_VULKAN_PERF
+
         ctx->device->device.resetFences({ ctx->fence });
 
         // Do staging buffer copies
@@ -6096,12 +6013,14 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     }
     ctx->gc.temp_buffers.clear();
 
-    for (auto& pipeline : ctx->device->pipelines) {
-        if (pipeline.expired()) {
+    for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) {
+        vk_pipeline_ref plr = ctx->device->pipelines[dsr.first];
+
+        if (plr.expired()) {
             continue;
         }
 
-        vk_pipeline pl = pipeline.lock();
+        vk_pipeline pl = plr.lock();
         ggml_pipeline_cleanup(pl);
     }
 
@@ -6125,10 +6044,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
         ctx->device->device.resetEvent(event);
     }
 
-    ctx->staging_offset = 0;
-
     ctx->tensor_ctxs.clear();
     ctx->gc.contexts.clear();
+    ctx->device->pipeline_descriptor_set_requirements.clear();
 }
 
 // Clean up on backend free
@@ -6139,7 +6057,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ggml_vk_destroy_buffer(ctx->prealloc_x);
     ggml_vk_destroy_buffer(ctx->prealloc_y);
     ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-    ggml_vk_destroy_buffer(ctx->staging);
 
     for (auto& buffer : ctx->buffer_pool) {
         ggml_vk_destroy_buffer(buffer);
@@ -6148,7 +6065,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ctx->prealloc_size_x = 0;
     ctx->prealloc_size_y = 0;
     ctx->prealloc_size_split_k = 0;
-    ctx->staging_size = 0;
 
     for (auto& event : ctx->gc.events) {
         ctx->device->device.destroyEvent(event);
@@ -6477,7 +6393,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+    ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
 }
 
 GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -6500,7 +6416,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+    ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
 }
 
 GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
@@ -6566,9 +6482,10 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_preallocate_buffers_graph(ctx, cgraph->nodes[i]);
+        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, 0, true);
     }
     ggml_vk_preallocate_buffers(ctx);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     int last_node = cgraph->n_nodes - 1;
 
@@ -6581,7 +6498,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
     ctx->tensor_ctxs.resize(cgraph->n_nodes);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node);
+        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node, false);
     }
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6607,6 +6524,10 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
         GGML_ASSERT(ok);
     }
 
+#ifdef GGML_VULKAN_PERF
+    ctx->device->perf_logger->print_timings();
+#endif
+
     ggml_vk_graph_cleanup(ctx);
 
     return GGML_STATUS_SUCCESS;
@@ -6698,10 +6619,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
                 return false;
             } break;
         case GGML_OP_REPEAT:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
+            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
         case GGML_OP_ROPE:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_NONE:
@@ -6713,6 +6631,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
@@ -7163,16 +7082,24 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_SIN) {
+        tensor_clone = ggml_sin(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_COS) {
+        tensor_clone = ggml_cos(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_CLAMP) {
         tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_PAD) {
         tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
+    } else if (tensor->op == GGML_OP_REPEAT) {
+        tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_ADD) {
         tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
+    } else if (tensor->op == GGML_OP_ACC) {
+        tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
     } else if (tensor->op == GGML_OP_NORM) {
         tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
-        tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params);
+        tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
diff --git a/src/vulkan-shaders/acc.comp b/src/vulkan-shaders/acc.comp
new file mode 100644 (file)
index 0000000..4c8739e
--- /dev/null
@@ -0,0 +1,24 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+    const uint idx = gl_GlobalInvocationID.x;
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint offset = p.param3;
+    const uint src1_i = idx - offset;
+    const uint oz = src1_i / p.nb02;
+    const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
+    const uint ox = src1_i % p.nb01;
+
+    if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
+        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+    } else {
+        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
+    }
+}
+
index 08ab5514bfb498173e9c6ab97c7fd5b7138aa9af..c23b6eb1b0cd50ba29b03cdd7dadff746c7e4321 100644 (file)
@@ -30,6 +30,10 @@ void main() {
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
     data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
 #else
-    data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+    if (is_src0) {
+        data_d[p.d_offset + dst_idx] = data_a[src0_idx];
+    } else {
+        data_d[p.d_offset + dst_idx] = data_b[src1_idx];
+    }
 #endif
 }
index 46a6369bcfd201ec582f6802ba7923f6429c7f7e..d3ccba7fcb3fda747545cea6db04890f228f1b98 100644 (file)
@@ -39,8 +39,7 @@ void main() {
         vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
 
         // matrix multiplication
-        tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
-                    FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
+        tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
     }
 
     // sum up partial sums and write back result
index cb3f3c0df0801ba0952ea9efe39b5f2c84b99471..1cc4996d393a2a33d67cc0be1b549ca72c721876 100644 (file)
@@ -53,7 +53,7 @@ void main() {
 
         const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
 
-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
     }
 
     // sum up partial sums and write back result
index 4b1871caaf4c2ddcb083bffca087904369164702..9b443807d87817c170d64f03ffa4bba406894972 100644 (file)
@@ -52,7 +52,7 @@ void main() {
         // y is not transposed but permuted
         const uint iy = channel*nrows_y + row_y;
 
-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
     }
 
     // dst is not transposed and not permuted
index 4cd97799df19631e213b7302d6105f9b662b3f02..ec8eadcd5828a426ee247c4be515257cdbaf921f 100644 (file)
@@ -39,24 +39,25 @@ void main() {
         FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
         FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
-            sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
+            sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1))))))));
+            sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF),
+                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2))))))));
         }
-        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx]));
     }
 
     // sum up partial sums and write back result
index a6e430ea08c5d3c7de4fa46e71cdaadc079a0778..3ca4ad85a5ca0ee6bbff92b6dce39af83ca5a454 100644 (file)
@@ -40,16 +40,17 @@ void main() {
 
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
+            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
         }
-        tmp[16 * ix + tid] += d * sum;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]);
     }
 
     // sum up partial sums and write back result
index 75569363c64a92d48e9063137dd2f68ca79bea70..d91e00e10061a29ae707fe2e51c2342f6d046404 100644 (file)
@@ -67,17 +67,17 @@ void main() {
         const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66]  >> 4);
         const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67]  >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx    ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx    ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]),      q4_0,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]),  q4_1,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]),  q4_2,  FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) *  q4_3)));
+        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6,  FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
+        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]),      q4_8,  fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]),  q4_9,  fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]),  q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) *  q4_11)));
+        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6,     FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
 #else
         const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
         const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
@@ -88,16 +88,19 @@ void main() {
         const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
         const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * q4_0  + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2  + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * q4_4  + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-        );
-
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
+        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
+        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
+        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
+        const FLOAT_TYPE smin =
+            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+          + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
+
+        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
+                        sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
+                       fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
 #endif
     }
 
index 9be3645bdea0eadd52f2fdedb99a9ca2b1a9ac07..2306785af4226159d016ea858f72385f14540fca 100644 (file)
@@ -66,35 +66,33 @@ void main() {
         const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80]  >> 4);
         const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81]  >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sy = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sz = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sw = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
-          + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+        const FLOAT_TYPE sx =
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sy =
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sz =
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE sw =
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0)),
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
+             FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
+        const FLOAT_TYPE smin =
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
+          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
+          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
+              (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
     }
 
     // sum up partial sums and write back result
index d610cf0306b0ab1b217b792ce2d54dad2d9806b5..95c286eeb17e1f3bc539db976f1e78f36261ba55 100644 (file)
@@ -44,22 +44,22 @@ void main() {
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
 #if K_QUANTS_PER_ITERATION == 1
-        FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
-        tmp[16 * ix + tid] += sum;
+        const uint tmp_idx = 16 * ix + tid;
+        tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32),
+                       fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
 #else
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
+            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
+                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
         }
         tmp[16 * ix + tid] += sum;
 #endif
index 5fe9d5241381e51536ff0f640b6e621f976d3e6f..fffdd18189d5586805ce7982fa11b503849a6135 100644 (file)
@@ -326,10 +326,10 @@ void main() {
                 mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
             }
             const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
+            const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx    ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);
+            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
@@ -357,10 +357,10 @@ void main() {
                 mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
             }
             const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
+            const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx    ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0)) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);
+            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
@@ -463,7 +463,8 @@ void main() {
                 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                     [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                         [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
+                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
                         }
                     }
                 }
diff --git a/src/vulkan-shaders/repeat.comp b/src/vulkan-shaders/repeat.comp
new file mode 100644 (file)
index 0000000..a86af87
--- /dev/null
@@ -0,0 +1,24 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+uint src0_idx_mod(uint idx) {
+    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+    return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;
+}
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
+}
index a451d24705136572b8f58b03f505835f6895233a..89ac99f29696bad9f42e59a1df03de3c729c31d8 100644 (file)
@@ -369,31 +369,31 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
+        string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {
-        string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
     tasks.push_back(std::async(std::launch::async, [] {