]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan : implement Stable Diffusion operators (#904)
author0cc4m <redacted>
Sun, 4 Aug 2024 15:28:08 +0000 (17:28 +0200)
committerGitHub <redacted>
Sun, 4 Aug 2024 15:28:08 +0000 (18:28 +0300)
* Fix Vulkan repeat op

* Implement Vulkan concat op

* Delete old Vulkan shader generator

* Implement Vulkan im2col op

* Implement Vulkan unary gelu_quick op

* Implement Vulkan group_norm op

* Implement Vulkan timestep_embedding op

* Implement Vulkan upscale op

* Fix Vulkan vk_context tensor extra index issue

* Fix Vulkan matmul shader parameter bug

* Properly fix Vulkan matmul shader parameter bug

* Add Vulkan ADD f16 + f32 -> f16 operator support

* Implement Vulkan tanh op

* Fix Vulkan group count too large Validation error on non-Nvidia GPUs

* Throw error when too much memory is requested

* Fix another Vulkan group count too large Validation error on non-Nvidia GPUs

* Fix matmul MMQ condition

* Implement Vulkan pad op

* Fix Vulkan crash when tensor is used multiple times in a compute graph

* Add Vulkan CONCAT f16 + f16 -> f16 op

* Add Vulkan LEAKY_RELU op

29 files changed:
src/ggml-vulkan.cpp
src/ggml_vk_generate_shaders.py [deleted file]
src/vulkan-shaders/add.comp
src/vulkan-shaders/clamp.comp
src/vulkan-shaders/concat.comp [new file with mode: 0644]
src/vulkan-shaders/copy.comp
src/vulkan-shaders/div.comp
src/vulkan-shaders/gelu.comp
src/vulkan-shaders/gelu_quick.comp [new file with mode: 0644]
src/vulkan-shaders/generic_binary_head.comp
src/vulkan-shaders/generic_unary_head.comp
src/vulkan-shaders/group_norm.comp [new file with mode: 0644]
src/vulkan-shaders/im2col.comp [new file with mode: 0644]
src/vulkan-shaders/leaky_relu.comp [new file with mode: 0644]
src/vulkan-shaders/mul.comp
src/vulkan-shaders/norm.comp
src/vulkan-shaders/pad.comp [new file with mode: 0644]
src/vulkan-shaders/relu.comp
src/vulkan-shaders/rms_norm.comp
src/vulkan-shaders/scale.comp
src/vulkan-shaders/silu.comp
src/vulkan-shaders/soft_max.comp
src/vulkan-shaders/square.comp
src/vulkan-shaders/sum_rows.comp
src/vulkan-shaders/tanh.comp [new file with mode: 0644]
src/vulkan-shaders/timestep_embedding.comp [new file with mode: 0644]
src/vulkan-shaders/types.comp
src/vulkan-shaders/upscale.comp [new file with mode: 0644]
src/vulkan-shaders/vulkan-shaders-gen.cpp

index fa68360b96e4aaec4f6e2dae7d712a875705da4f..d7fea78d072b3aa23beb37d5ac7648555936ce42 100644 (file)
@@ -177,24 +177,33 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
     vk_pipeline pipeline_mul_f32;
     vk_pipeline pipeline_div_f32;
-    vk_pipeline pipeline_add_f32;
+    vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
+    vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
     vk_pipeline pipeline_clamp_f32;
+    vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
     vk_pipeline pipeline_norm_f32;
+    vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_gelu_f32;
+    vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
     vk_pipeline pipeline_relu_f32;
+    vk_pipeline pipeline_leaky_relu_f32;
+    vk_pipeline pipeline_tanh_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
+    vk_pipeline pipeline_timestep_embedding_f32;
 
     std::vector<vk_pipeline_ref> pipelines;
 
@@ -320,7 +329,7 @@ struct vk_op_binary_push_constants {
     uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
     uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
     uint32_t d_offset;
-    float param1; float param2;
+    float param1; float param2; int32_t param3;
 };
 
 struct vk_op_diag_mask_push_constants {
@@ -358,6 +367,25 @@ struct vk_op_argsort_push_constants {
     int32_t order;
 };
 
+struct vk_op_im2col_push_constants {
+    uint32_t batch_offset; uint32_t offset_delta;
+    uint32_t IC;
+    uint32_t IW; uint32_t IH;
+    uint32_t OW; uint32_t OH;
+    uint32_t KW; uint32_t KH;
+    uint32_t pelements;
+    uint32_t CHW;
+    int32_t s0; int32_t s1;
+    int32_t p0; int32_t p1;
+    int32_t d0; int32_t d1;
+};
+
+struct vk_op_timestep_embedding_push_constants {
+    uint32_t nb1;
+    uint32_t dim;
+    uint32_t max_period;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -367,28 +395,32 @@ struct vk_staging_memcpy {
     size_t n;
 };
 
-struct vk_context {
-    size_t idx;
+struct vk_op_upscale_push_constants {
+    uint32_t ne; uint32_t d_offset;
+    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+    float sf0; float sf1; float sf2; float sf3;
+};
 
+struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
 
-    ggml_tensor * exit_tensor;
+    int exit_tensor_idx;
 
     std::vector<vk_staging_memcpy> in_memcpys;
     std::vector<vk_staging_memcpy> out_memcpys;
 
     vk_queue * q;
 };
+typedef std::shared_ptr<vk_context_struct> vk_context;
+typedef std::weak_ptr<vk_context_struct> vk_context_ref;
 
 struct ggml_tensor_extra_gpu {
-    size_t ctx_idx;
-
     vk_buffer_ref buffer_gpu;
     uint64_t offset;
 
     void reset() {
-        ctx_idx = 0;
         buffer_gpu.reset();
         offset = 0;
     }
@@ -459,8 +491,10 @@ struct ggml_backend_vk_context {
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
-    vk_context * compute_ctx;
-    vk_context * transfer_ctx;
+    vk_context_ref compute_ctx;
+    vk_context_ref transfer_ctx;
+
+    std::vector<vk_context_ref> tensor_ctxs;
 };
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
@@ -510,12 +544,12 @@ static vk_instance_t vk_instance;
 static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
-static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
+static void ggml_vk_check_results_0(ggml_tensor * tensor);
+static void ggml_vk_check_results_1(ggml_tensor * tensor);
 #endif
 
-typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 
 GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
 
@@ -708,11 +742,11 @@ static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, s
     return s;
 }
 
-static void ggml_vk_submit(vk_context * ctx, vk::Fence fence) {
-    VK_LOG_DEBUG("ggml_vk_submit(" << ctx->seqs.size() << ", " << fence << ")");
+static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
     if (ctx->seqs.empty()) {
         return;
     }
+    VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
 
     std::vector<std::vector<uint64_t>> tl_wait_vals;
     std::vector<std::vector<uint64_t>> tl_signal_vals;
@@ -844,21 +878,17 @@ static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_
     q.stage_flags = stage_flags;
 }
 
-static vk_context * ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
-    VK_LOG_DEBUG("ggml_vk_create_context()");
-    ctx->gc.contexts.emplace_back();
-    vk_context * result = &ctx->gc.contexts[ctx->gc.contexts.size() - 1];
-    memset((void *) result, 0, sizeof(vk_context));
-    result->idx = ctx->gc.contexts.size() - 1;
+static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
+    vk_context result = std::make_shared<vk_context_struct>();
+    VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
+    ctx->gc.contexts.emplace_back(result);
     result->q = &q;
     return result;
 }
 
-static vk_context * ggml_vk_create_temporary_context(vk_queue& q) {
-    VK_LOG_DEBUG("ggml_vk_create_temporary_context()");
-    vk_context * result = new vk_context;
-    memset((void *) result, 0, sizeof(vk_context));
-    result->idx = 0;
+static vk_context ggml_vk_create_temporary_context(vk_queue& q) {
+    vk_context result = std::make_shared<vk_context_struct>();
+    VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
     result->q = &q;
     return result;
 }
@@ -915,6 +945,10 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
 
 static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
     VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+    if (size > device->max_memory_allocation_size) {
+        throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
+    }
+
     std::lock_guard<std::mutex> guard(device->mutex);
 
     vk_buffer buf = std::make_shared<vk_buffer_struct>();
@@ -1027,7 +1061,7 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
     return { buf, 0, VK_WHOLE_SIZE };
 }
 
-static void ggml_vk_sync_buffers(vk_context * ctx) {
+static void ggml_vk_sync_buffers(vk_context& ctx) {
     VK_LOG_DEBUG("ggml_vk_sync_buffers()");
     const std::vector<vk::MemoryBarrier> mem_barriers{ { { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite }, { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite } } };
 
@@ -1041,7 +1075,7 @@ static void ggml_vk_sync_buffers(vk_context * ctx) {
     );
 }
 
-static void ggml_vk_wait_events(vk_context * ctx, std::vector<vk::Event>&& events) {
+static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
     VK_LOG_DEBUG("ggml_vk_wait_events()");
     if (events.empty()) {
         return;
@@ -1598,6 +1632,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -1605,20 +1640,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
+    ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
 
@@ -1634,6 +1680,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+
+    ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
 }
 
 static vk_device ggml_vk_get_device(size_t idx) {
@@ -2077,9 +2128,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     ctx->staging_size = 0;
     ctx->staging_offset = 0;
 
-    ctx->compute_ctx = nullptr;
-    ctx->transfer_ctx = nullptr;
-
 #ifdef GGML_VULKAN_CHECK_RESULTS
     const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
     vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
@@ -2112,7 +2160,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
 }
 
 static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
-    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline()");
+    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_f32;
     }
@@ -2126,7 +2174,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         return ctx->device->pipeline_matmul_f16;
     }
 
-    GGML_ASSERT(src1_type == GGML_TYPE_F32);
+    if (src1_type != GGML_TYPE_F32) {
+        return nullptr;
+    }
 
     switch (src0_type) {
         case GGML_TYPE_Q4_0:
@@ -2370,7 +2420,7 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }
 
-static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
+static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
     const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
@@ -2410,7 +2460,7 @@ static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> w
     s.signal_semaphores = std::move(signal_semaphores);
 }
 
-static void ggml_vk_ctx_end(vk_context * ctx) {
+static void ggml_vk_ctx_end(vk_context& ctx) {
     VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
     if (ctx->s == nullptr) {
         return;
@@ -2420,7 +2470,7 @@ static void ggml_vk_ctx_end(vk_context * ctx) {
     ctx->s = nullptr;
 }
 
-static void ggml_vk_ctx_begin(vk_device& device, vk_context * subctx) {
+static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
     VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
     if (subctx->s != nullptr) {
         ggml_vk_ctx_end(subctx);
@@ -2453,7 +2503,7 @@ static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
     }
 }
 
-static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
+static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
     GGML_ASSERT(!ggml_is_contiguous(tensor));
     // Buffer is already mapped
@@ -2558,7 +2608,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     }
 }
 
-static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
     // Buffer is already mapped
     if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -2623,7 +2673,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context * subctx, vk_buffer& dst, s
     }
 }
 
-static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
     return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging);
 }
@@ -2638,7 +2688,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
             memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
         }
     } else {
-        vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+        vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
         ggml_vk_ctx_begin(dst->device, subctx);
         ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true);
         ggml_vk_ctx_end(subctx);
@@ -2650,8 +2700,6 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
         ggml_vk_submit(subctx, dst->device->fence);
         VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
         dst->device->device.resetFences({ dst->device->fence });
-
-        delete subctx;
     }
 }
 
@@ -2660,12 +2708,14 @@ static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src
     ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
 }
 
-static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
     GGML_ASSERT(width > 0);
     GGML_ASSERT(height > 0);
     GGML_ASSERT(src != nullptr);
 
+    // TODO: staging_offset is not used
+
     // Check if dst is pinned memory
     vk_buffer buf = nullptr;
     size_t buf_offset;
@@ -2714,18 +2764,18 @@ static void ggml_vk_buffer_read_2d_async(vk_context * subctx, vk_buffer& src, si
     deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
 }
 
-static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging);
 }
 
 static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_buffer_read(" << offset << ", " << size << ")");
+    VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
     if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
         GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
 
         memcpy(dst, (uint8_t *) src->ptr + offset, size);
     } else {
-        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
         ggml_vk_ctx_begin(src->device, subctx);
         ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true);
         ggml_vk_ctx_end(subctx);
@@ -2737,12 +2787,10 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
         for (auto& cpy : subctx->out_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
-
-        delete subctx;
     }
 }
 
-static void ggml_vk_buffer_copy_async(vk_context * ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
+static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
     // Make sure both buffers are on same device
     GGML_ASSERT(src->device == dst->device);
@@ -2756,15 +2804,13 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
     if (src->device == dst->device) {
         VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
         // Copy within the device
-        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
         ggml_vk_ctx_begin(src->device, subctx);
         ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
         ggml_vk_ctx_end(subctx);
         ggml_vk_submit(subctx, src->device->fence);
         VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
         src->device->device.resetFences({ src->device->fence });
-
-        delete subctx;
     } else {
         VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
         // Copy device to device
@@ -2783,7 +2829,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
 static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
 
-    vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+    vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
     ggml_vk_ctx_begin(dst->device, subctx);
     subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
     ggml_vk_ctx_end(subctx);
@@ -2791,8 +2837,6 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     ggml_vk_submit(subctx, dst->device->fence);
     VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
     dst->device->device.resetFences({ dst->device->fence });
-
-    delete subctx;
 }
 
 static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
@@ -2855,7 +2899,7 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct
 }
 
 static void ggml_vk_matmul(
-        ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
@@ -2879,7 +2923,7 @@ static void ggml_vk_matmul(
 }
 
 static void ggml_vk_matmul_id(
-        ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
@@ -2916,7 +2960,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_
     GGML_ABORT("fatal error");
 }
 
-static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
+static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
     VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
     std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
     const int tensor_type_size = ggml_type_size(tensor->type);
@@ -2934,7 +2978,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
 }
 
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
@@ -3107,7 +3151,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     );  // NOLINT
 }
 
-static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
@@ -3272,7 +3316,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
                               sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 }
 
-static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
@@ -3343,7 +3387,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
-static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
@@ -3418,7 +3462,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
     if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
         ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst);
@@ -3431,7 +3475,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
     }
 }
 
-static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
@@ -3618,7 +3662,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
     );  // NOLINT
 }
 
-static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
@@ -3794,7 +3838,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
 }
 
-static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
     if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst);
@@ -3803,8 +3847,8 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subct
     }
 }
 
-static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // guaranteed to be an integer due to the check in ggml_can_repeat
+static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    VK_LOG_DEBUG("ggml_vk_op_repeat(" << src0 << ", " << src1 << ", " << dst << ")");
     const uint64_t ne0 = dst->ne[0];
     const uint64_t ne1 = dst->ne[1];
     const uint64_t ne2 = dst->ne[2];
@@ -3825,6 +3869,7 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
     const uint64_t nb02 = src0->nb[2];
     const uint64_t nb03 = src0->nb[3];
 
+    // guaranteed to be an integer due to the check in ggml_can_repeat
     const uint64_t nr0 = ne0/ne00;
     const uint64_t nr1 = ne1/ne01;
     const uint64_t nr2 = ne2/ne02;
@@ -3852,8 +3897,8 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
                         for     (uint64_t k1 = 0; k1 < ne01; k1++) {
                             for (uint64_t i0 = 0; i0 < nr0;  i0++) {
                                 copies.push_back({
-                                    src_offset + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0,
-                                    dst_offset + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01,
+                                    src_offset + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01,
+                                    dst_offset + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
                                     ne00*nb0,
                                 });
                             }
@@ -3874,11 +3919,6 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
     switch (op) {
-    case GGML_OP_ADD:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_add_f32;
-        }
-        return nullptr;
     case GGML_OP_GET_ROWS:
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
         if (dst->type == GGML_TYPE_F16) {
@@ -3888,6 +3928,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_get_rows_f32[src0->type];
         }
         return nullptr;
+    case GGML_OP_ADD:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add_f32;
+        }
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_add_f16_f32_f16;
+        }
+        return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_mul_f32;
@@ -3898,6 +3946,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_div_f32;
         }
         return nullptr;
+    case GGML_OP_CONCAT:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_concat_f32;
+        }
+        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_concat_f16;
+        }
+        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+            return ctx->device->pipeline_concat_i32;
+        }
+        return nullptr;
+    case GGML_OP_UPSCALE:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_upscale_f32;
+        }
+        return nullptr;
     case GGML_OP_SCALE:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_scale_f32;
@@ -3913,6 +3977,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_clamp_f32;
         }
         return nullptr;
+    case GGML_OP_PAD:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_pad_f32;
+        }
+        return nullptr;
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
@@ -3922,6 +3991,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_norm_f32;
         }
         return nullptr;
+    case GGML_OP_GROUP_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_group_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_rms_norm_f32;
@@ -3939,11 +4013,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                     return ctx->device->pipeline_gelu_f32;
                 }
                 break;
+            case GGML_UNARY_OP_GELU_QUICK:
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_gelu_quick_f32;
+                }
+                break;
             case GGML_UNARY_OP_RELU:
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
                     return ctx->device->pipeline_relu_f32;
                 }
                 break;
+            case GGML_UNARY_OP_TANH:
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_tanh_f32;
+                }
+                break;
             default:
                 break;
         }
@@ -3995,6 +4079,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sum_rows_f32;
         }
         return nullptr;
+    case GGML_OP_IM2COL:
+        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_im2col_f32;
+        }
+        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_im2col_f32_f16;
+        }
+        return nullptr;
+    case GGML_OP_TIMESTEP_EMBEDDING:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_timestep_embedding_f32;
+        }
+        return nullptr;
+    case GGML_OP_LEAKY_RELU:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_leaky_relu_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -4018,9 +4120,12 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_ADD:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_CONCAT:
+    case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_CLAMP:
+    case GGML_OP_PAD:
         return true;
     default:
         return false;
@@ -4028,7 +4133,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
 }
 
 template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
     VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     if (src1 != nullptr) {
         std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4124,7 +4229,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     vk_buffer d_D = extra->buffer_gpu.lock();
 
     // Workaround for tiny tensor inputs on ROPE
-    if (use_src1 && y_sz > d_D->size) {
+    if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
         y_sz = VK_WHOLE_SIZE;
     }
 
@@ -4173,13 +4278,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
         ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
 
-        switch (dst->op) {
+        switch (op) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_SUM_ROWS:
-            elements = { (uint32_t)ggml_nrows(src0), 1, 1 };
-            break;
+            {
+                const uint32_t nr = ggml_nrows(src0);
+                if (nr > 262144) {
+                    elements = { 512, 512, CEIL_DIV(nr, 262144) };
+                } else if (nr > 512) {
+                    elements = { 512, CEIL_DIV(nr, 512), 1 };
+                } else {
+                    elements = { nr, 1, 1 };
+                }
+            } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                const uint32_t num_groups = dst->op_params[0];
+                elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
+            } break;
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ROPE:
             elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
@@ -4190,6 +4308,49 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         case GGML_OP_ARGSORT:
             elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
             break;
+        case GGML_OP_IM2COL:
+            {
+                const bool is_2D = dst->op_params[6] == 1;
+
+                const uint32_t IC = src1->ne[is_2D ? 2 : 1];
+
+                const uint32_t KH = is_2D ? src0->ne[1] : 1;
+                const uint32_t KW =         src0->ne[0];
+
+                const uint32_t OH = is_2D ? dst->ne[2] : 1;
+                const uint32_t OW =         dst->ne[1];
+
+                const uint32_t batch = src1->ne[3];
+
+                elements = { OW * KW * KH, OH, batch * IC };
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                const uint32_t dim = dst->op_params[0];
+                uint32_t half_ceil = (dim + 1) / 2;
+                elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
+            } break;
+        case GGML_OP_ADD:
+        case GGML_OP_DIV:
+        case GGML_OP_MUL:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_CLAMP:
+        case GGML_OP_PAD:
+        case GGML_OP_CPY:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_UNARY:
+            {
+                const uint32_t ne = ggml_nelements(dst);
+                if (ne > 262144) {
+                    elements = { 512, 512, CEIL_DIV(ne, 262144) };
+                } else if (ne > 512) {
+                    elements = { 512, CEIL_DIV(ne, 512), 1 };
+                } else {
+                    elements = { ne, 1, 1 };
+                }
+            } break;
         default:
             elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
             break;
@@ -4216,7 +4377,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             if (use_src1) {
                 subbuf_y = { d_Y, y_buf_offset, y_sz };
             } else {
-                subbuf_y = { d_X, 0, d_X->size };
+                subbuf_y = { d_X, 0, x_sz };
             }
 
             ggml_vk_sync_buffers(subctx);
@@ -4227,11 +4388,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             if (use_src2) {
                 subbuf_z = { d_Z, z_buf_offset, z_sz };
             } else {
-                subbuf_z = { d_X, 0, d_X->size };
+                subbuf_z = { d_X, 0, x_sz };
             }
 
             ggml_vk_sync_buffers(subctx);
             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        } else if (op == GGML_OP_IM2COL) {
+            // im2col uses only src1 and dst buffers
+            ggml_vk_sync_buffers(subctx);
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
         } else if (use_src2) {
             ggml_vk_sync_buffers(subctx);
             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -4249,8 +4414,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
 
         ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03);
 
-        switch (dst->op) {
+        switch (op) {
         case GGML_OP_NORM:
+        case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
             elements = { (uint32_t)ne01, 1, 1 };
             break;
@@ -4286,11 +4452,11 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     }
 }
 
-static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {});
 }
 
-static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4301,11 +4467,11 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
-        0.0f, 0.0f,
+        0.0f, 0.0f, 0,
     });
 }
 
-static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4316,11 +4482,11 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, cons
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
-        0.0f, 0.0f,
+        0.0f, 0.0f, 0,
     });
 }
 
-static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4331,11 +4497,11 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, cons
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
-        0.0f, 0.0f,
+        0.0f, 0.0f, 0,
     });
 }
 
-static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4346,11 +4512,44 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, cons
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
-        0.0f, 0.0f,
+        0.0f, 0.0f, 0,
     });
 }
 
-static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    int * op_params = (int *)dst->op_params;
+
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f, op_params[0],
+    });
+}
+
+static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+
+    const float sf0 = (float)dst->ne[0] / src0->ne[0];
+    const float sf1 = (float)dst->ne[1] / src0->ne[1];
+    const float sf2 = (float)dst->ne[2] / src0->ne[2];
+    const float sf3 = (float)dst->ne[3] / src0->ne[3];
+
+    ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
+        (uint32_t)ggml_nelements(dst), 0,
+        (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
+        sf0, sf1, sf2, sf3,
+    });
+}
+
+static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4364,7 +4563,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, co
     });
 }
 
-static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
@@ -4377,7 +4576,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     });
 }
 
-static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4391,7 +4590,20 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, co
     });
 }
 
-static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
+static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -4406,27 +4618,37 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     });
 }
 
-static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
 
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    int * op_params = (int *)dst->op_params;
+
+    uint32_t num_groups = op_params[0];
+    uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+    static const float eps = 1e-6f;
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
+}
+
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }
 
-static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
 }
 
-static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     int32_t * op_params = (int32_t *)dst->op_params;
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
 }
 
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
 
     float scale = op_params[0];
@@ -4451,7 +4673,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
     });
 }
 
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     // const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
@@ -4475,7 +4697,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     });
 }
 
-static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     int32_t * op_params = (int32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
@@ -4494,10 +4716,59 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx,
     });
 }
 
-static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f });
 }
 
+static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t p0 = dst->op_params[2];
+    const int32_t p1 = dst->op_params[3];
+    const int32_t d0 = dst->op_params[4];
+    const int32_t d1 = dst->op_params[5];
+
+    const bool is_2D = dst->op_params[6] == 1;
+
+    const uint32_t IC = src1->ne[is_2D ? 2 : 1];
+    const uint32_t IH = is_2D ? src1->ne[1] : 1;
+    const uint32_t IW =         src1->ne[0];
+
+    const uint32_t KH = is_2D ? src0->ne[1] : 1;
+    const uint32_t KW =         src0->ne[0];
+
+    const uint32_t OH = is_2D ? dst->ne[2] : 1;
+    const uint32_t OW =         dst->ne[1];
+
+    const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+    const uint32_t pelements = OW * KW * KH;
+
+    ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
+        batch_offset, offset_delta,
+        IC, IW, IH, OW, OH, KW, KH,
+        pelements,
+        IC * KH * KW,
+        s0, s1, p0, p1, d0, d1,
+    });
+}
+
+static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t dim = dst->op_params[0];
+    const uint32_t max_period = dst->op_params[1];
+    const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
+        nb1, dim, max_period,
+    });
+}
+
+static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const float * op_params = (const float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+}
+
 #ifdef GGML_VULKAN_RUN_TESTS
 static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
     if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
@@ -4686,7 +4957,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
     ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
 
-    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     for (size_t i = 0; i < num_it; i++) {
         ggml_vk_ctx_begin(ctx->device, subctx);
         ggml_vk_matmul(
@@ -4894,7 +5165,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
 
-    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
     const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
     ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
@@ -5027,7 +5298,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
     ggml_vk_buffer_write(y_buf, 0, y, y_sz);
 
-    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     for (size_t i = 0; i < num_it; i++) {
         ggml_vk_ctx_begin(ctx->device, subctx);
         ggml_vk_matmul(
@@ -5175,7 +5446,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
 
     const bool y_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    bool mmp = (use_src0 && use_src1 && src1_type == GGML_TYPE_F32) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false;
+    bool mmp = (use_src0 && use_src1 && (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID)) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false;
 
     const bool qx_needs_dequant = use_src0 && (!mmp || x_non_contig);
     const bool qy_needs_dequant = use_src1 && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
@@ -5211,24 +5482,33 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_CLAMP:
+    case GGML_OP_PAD:
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_CONCAT:
+    case GGML_OP_UPSCALE:
     case GGML_OP_NORM:
+    case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_IM2COL:
+    case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_LEAKY_RELU:
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_TANH:
             break;
         default:
             return;
@@ -5236,6 +5516,13 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
         break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
+        if (
+                x_sz > ctx->device->max_memory_allocation_size ||
+                y_sz > ctx->device->max_memory_allocation_size ||
+                d_sz > ctx->device->max_memory_allocation_size ||
+                split_k_size > ctx->device->max_memory_allocation_size) {
+            GGML_ABORT("Requested preallocation size is too large");
+        }
         if (ctx->prealloc_size_x < x_sz) {
             ctx->prealloc_size_x = x_sz;
         }
@@ -5430,7 +5717,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){
+static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node){
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
 
     if (ggml_is_empty(node) || extra == nullptr) {
@@ -5457,7 +5744,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_TANH:
             break;
         default:
             return;
@@ -5468,13 +5757,17 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ADD:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_CONCAT:
+    case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_CLAMP:
+    case GGML_OP_PAD:
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_NORM:
+    case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -5483,6 +5776,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_IM2COL:
+    case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_LEAKY_RELU:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -5490,102 +5786,137 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         return;
     }
 
-    if (ctx->compute_ctx == nullptr) {
-        ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-        ggml_vk_ctx_begin(ctx->device, ctx->compute_ctx);
+    vk_context compute_ctx;
+
+    if (ctx->compute_ctx.expired()) {
+        compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+        ctx->compute_ctx = compute_ctx;
+        ggml_vk_ctx_begin(ctx->device, compute_ctx);
+    } else {
+        compute_ctx = ctx->compute_ctx.lock();
     }
 
     switch (node->op) {
     case GGML_OP_REPEAT:
-        ggml_vk_repeat(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_repeat(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_GET_ROWS:
-        ggml_vk_get_rows(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
 
         break;
     case GGML_OP_ADD:
-        ggml_vk_add(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_add(ctx, compute_ctx, src0, src1, node);
 
         break;
     case GGML_OP_MUL:
-        ggml_vk_mul(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
 
         break;
     case GGML_OP_DIV:
-        ggml_vk_div(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_div(ctx, compute_ctx, src0, src1, node);
+
+        break;
+    case GGML_OP_CONCAT:
+        ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
+
+        break;
+    case GGML_OP_UPSCALE:
+        ggml_vk_upscale(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_SCALE:
-        ggml_vk_scale(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_scale(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_SQR:
-        ggml_vk_sqr(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_sqr(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_CLAMP:
-        ggml_vk_clamp(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_clamp(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_PAD:
+        ggml_vk_pad(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
-        ggml_vk_cpy(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_cpy(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_NORM:
-        ggml_vk_norm(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_norm(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_GROUP_NORM:
+        ggml_vk_group_norm(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_rms_norm(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
-            ggml_vk_unary(ctx, ctx->compute_ctx, src0, node);
+        case GGML_UNARY_OP_TANH:
+            ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
         default:
             return;
         }
         break;
     case GGML_OP_DIAG_MASK_INF:
-        ggml_vk_diag_mask_inf(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node);
 
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, src2, node);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node);
 
         break;
     case GGML_OP_ARGSORT:
-        ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_argsort(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_SUM_ROWS:
-        ggml_vk_sum_rows(ctx, ctx->compute_ctx, src0, node);
+        ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_IM2COL:
+        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
+
+        break;
+    case GGML_OP_TIMESTEP_EMBEDDING:
+        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_LEAKY_RELU:
+        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
 
         break;
     case GGML_OP_MUL_MAT:
-        ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node);
 
         break;
     case GGML_OP_MUL_MAT_ID:
-        ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, src2, node);
+        ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node);
 
         break;
     default:
         return;
     }
 
-    extra->ctx_idx = ctx->compute_ctx->idx;
+    ctx->tensor_ctxs[node_idx] = compute_ctx;
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
     // Force context reset on each node so that each tensor ends up in its own context
@@ -5594,13 +5925,13 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 #endif
 
     if (last_node) {
-        ggml_vk_ctx_end(ctx->compute_ctx);
-        ctx->compute_ctx->exit_tensor = node;
-        ctx->compute_ctx = nullptr;
+        ggml_vk_ctx_end(compute_ctx);
+        compute_ctx->exit_tensor_idx = node_idx;
+        ctx->compute_ctx.reset();
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx){
     ggml_tensor_extra_gpu * extra = nullptr;
 
     switch (tensor->op) {
@@ -5608,13 +5939,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_GET_ROWS:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_CONCAT:
+    case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_CLAMP:
+    case GGML_OP_PAD:
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_NORM:
+    case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -5626,6 +5961,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_NONE:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_IM2COL:
+    case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_LEAKY_RELU:
+    case GGML_OP_REPEAT:
         extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
         break;
@@ -5633,7 +5972,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_TANH:
             extra = (ggml_tensor_extra_gpu *) tensor->extra;
             break;
         default:
@@ -5656,31 +5997,31 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
-    ggml_vk_check_results_0(ctx, tensor);
+    ggml_vk_check_results_0(tensor);
 #endif
 
-    vk_context& subctx = ctx->gc.contexts[extra->ctx_idx];
+    vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
 
     // Only run if ctx hasn't been submitted yet
-    if (!subctx.seqs.empty()) {
+    if (!subctx->seqs.empty()) {
         // Do staging buffer copies
-        for (auto& cpy : subctx.in_memcpys) {
+        for (auto& cpy : subctx->in_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
-        ggml_vk_submit(&subctx, ctx->fence);
+        ggml_vk_submit(subctx, ctx->fence);
     }
 
-    if (tensor == subctx.exit_tensor) {
+    if (tensor_idx == subctx->exit_tensor_idx) {
         VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
         ctx->device->device.resetFences({ ctx->fence });
 
         // Do staging buffer copies
-        for (auto& cpy : subctx.out_memcpys) {
+        for (auto& cpy : subctx->out_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
-        subctx.in_memcpys.clear();
-        subctx.out_memcpys.clear();
+        subctx->in_memcpys.clear();
+        subctx->out_memcpys.clear();
     }
 
     return true;
@@ -5725,8 +6066,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
 
     ctx->staging_offset = 0;
 
-    ctx->compute_ctx = nullptr;
-    ctx->transfer_ctx = nullptr;
+    ctx->tensor_ctxs.clear();
     ctx->gc.contexts.clear();
 }
 
@@ -6063,15 +6403,20 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
-    if (ctx->transfer_ctx == nullptr) {
+    vk_context transfer_ctx;
+
+    if (ctx->transfer_ctx.expired()) {
         // Initialize new transfer context
-        ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+        transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+        ctx->transfer_ctx = transfer_ctx;
+        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+    } else {
+        transfer_ctx = ctx->transfer_ctx.lock();
     }
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_write_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+    ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
 }
 
 GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -6081,15 +6426,20 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
-    if (ctx->transfer_ctx == nullptr) {
+    vk_context transfer_ctx;
+
+    if (ctx->transfer_ctx.expired()) {
         // Initialize new transfer context
-        ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+        transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+        ctx->transfer_ctx = transfer_ctx;
+        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+    } else {
+        transfer_ctx = ctx->transfer_ctx.lock();
     }
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_read_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+    ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
 }
 
 GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
@@ -6099,16 +6449,21 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
         ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
         ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 
-        if (ctx->transfer_ctx == nullptr) {
+        vk_context transfer_ctx;
+
+        if (ctx->transfer_ctx.expired()) {
             // Initialize new transfer context
-            ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-            ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+            transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+            ctx->transfer_ctx = transfer_ctx;
+            ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+        } else {
+            transfer_ctx = ctx->transfer_ctx.lock();
         }
 
         vk_buffer src_buf = src_extra->buffer_gpu.lock();
         vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
 
-        ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
+        ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
         return true;
     }
 
@@ -6118,25 +6473,27 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
 GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
     VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    if(ctx->transfer_ctx == nullptr) {
+    if(ctx->transfer_ctx.expired()) {
         return;
     }
 
-    ggml_vk_ctx_end(ctx->transfer_ctx);
+    vk_context transfer_ctx = ctx->transfer_ctx.lock();
+
+    ggml_vk_ctx_end(transfer_ctx);
 
-    for (auto& cpy : ctx->transfer_ctx->in_memcpys) {
+    for (auto& cpy : transfer_ctx->in_memcpys) {
         memcpy(cpy.dst, cpy.src, cpy.n);
     }
 
-    ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
+    ggml_vk_submit(transfer_ctx, ctx->fence);
     VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
     ctx->device->device.resetFences({ ctx->fence });
 
-    for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
+    for (auto& cpy : transfer_ctx->out_memcpys) {
         memcpy(cpy.dst, cpy.src, cpy.n);
     }
 
-    ctx->transfer_ctx = nullptr;
+    ctx->transfer_ctx.reset();
 }
 
 static bool ggml_vk_is_empty(ggml_tensor * node) {
@@ -6159,8 +6516,11 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
         last_node -= 1;
     }
 
+    // Reserve tensor context space for all nodes
+    ctx->tensor_ctxs.resize(cgraph->n_nodes);
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node);
+        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node);
     }
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6170,13 +6530,17 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
             continue;
         }
 
-        bool ok = ggml_vk_compute_forward(ctx, node);
+        bool ok = ggml_vk_compute_forward(ctx, node, i);
         if (!ok) {
-            fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+            if (node->op == GGML_OP_UNARY) {
+                std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
+            } else {
+                std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
+            }
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
         else {
-            ggml_vk_check_results_1(ctx, node);
+            ggml_vk_check_results_1(node);
         }
 #endif
         GGML_ASSERT(ok);
@@ -6196,8 +6560,10 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_TANH:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -6270,11 +6636,11 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
                 }
                 return false;
             } break;
-        // case GGML_OP_REPEAT:
-        //     {
-        //         ggml_type src0_type = op->src[0]->type;
-        //         return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-        //     } break;
+        case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
         case GGML_OP_ROPE:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_NONE:
@@ -6283,18 +6649,25 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_RMS_NORM:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_RMS_NORM:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
+        case GGML_OP_PAD:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_IM2COL:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_LEAKY_RELU:
             return true;
         default:
             return false;
@@ -6509,10 +6882,12 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d
     }
 }
 
-static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) {
+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
     void * tensor_data = tensor->data;
 
-    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+    const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
+
+    if (is_gpu) {
         const size_t tensor_size = ggml_nbytes(tensor);
         tensor_data = malloc(tensor_size);
 
@@ -6533,13 +6908,10 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
     std::cerr << std::endl << "Result:" << std::endl;
     ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
     std::cerr << std::endl;
-    std::cerr << std::endl << "Result:" << std::endl;
-    ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0);
-    std::cerr << std::endl;
     std::vector<const ggml_tensor *> done;
     ggml_vk_print_graph_origin(tensor, done);
 
-    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+    if (is_gpu) {
         free(tensor_data);
     }
 }
@@ -6548,8 +6920,8 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
-        if (tensor->op == GGML_OP_TRANSPOSE) {
+static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+    if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
@@ -6565,7 +6937,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
     ggml_tensor * src2 = tensor->src[2];
 
     struct ggml_init_params iparams = {
-        /*.mem_size   =*/ 1024*1024*1024,
+        /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul,
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ false,
     };
@@ -6624,7 +6996,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         }
 
         if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(ctx, src0, "src0");
+            ggml_vk_print_tensor(src0, "src0");
         }
     }
     if (src1 != nullptr) {
@@ -6666,23 +7038,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         }
 
         if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(ctx, src1, "src1");
-            std::cerr << "TENSOR CHECK: " << ggml_op_name(src1_clone->op) << " (check " << check_counter << ")" << std::endl;
-            std::cerr << "src1_clone=" << tensor << " src1_clone->type: " << ggml_type_name(src1_clone->type) << " ne0=" << src1_clone->ne[0] << " nb0=" << src1_clone->nb[0] << " ne1=" << src1_clone->ne[1] << " nb1=" << src1_clone->nb[1] << " ne2=" << src1_clone->ne[2] << " nb2=" << src1_clone->nb[2] << " ne3=" << src1_clone->ne[3] << " nb3=" << src1_clone->nb[3] << std::endl;
-            if (src1->src[0] != nullptr) {
-                std::cerr << "src1->src[0]=" << src1->src[0] << " op=" << ggml_op_name(src1->src[0]->op) << " type=" << ggml_type_name(src1->src[0]->type) << " ne0=" << src1->src[0]->ne[0] << " nb0=" << src1->src[0]->nb[0] << " ne1=" << src1->src[0]->ne[1] << " nb1=" << src1->src[0]->nb[1] << " ne2=" << src1->src[0]->ne[2] << " nb2=" << src1->src[0]->nb[2] << " ne3=" << src1->src[0]->ne[3] << " nb3=" << src1->src[0]->nb[3] << std::endl;
-            }
-            if (src1->src[1] != nullptr) {
-                std::cerr << "src1->src[1]=" << src1->src[1] << " op=" << ggml_op_name(src1->src[1]->op) << " type=" << ggml_type_name(src1->src[1]->type) << " ne0=" << src1->src[1]->ne[0] << " nb0=" << src1->src[1]->nb[0] << " ne1=" << src1->src[1]->ne[1] << " nb1=" << src1->src[1]->nb[1] << " ne2=" << src1->src[1]->ne[2] << " nb2=" << src1->src[1]->nb[2] << " ne3=" << src1->src[1]->ne[3] << " nb3=" << src1->src[1]->nb[3] << std::endl;
-            }
-            std::cerr << std::endl << "Result:" << std::endl;
-            ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 0, 0);
-            std::cerr << std::endl;
-            std::cerr << std::endl << "Result:" << std::endl;
-            ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 1, 0);
-            std::cerr << std::endl;
-            std::vector<const ggml_tensor *> done;
-            ggml_vk_print_graph_origin(src1_clone, done);
+            ggml_vk_print_tensor(src1, "src1");
         }
     }
     if (src2 != nullptr) {
@@ -6724,23 +7080,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         }
 
         if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(ctx, src2, "src2");
-            std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
-            std::cerr << "src2_clone=" << tensor << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
-            if (src2->src[0] != nullptr) {
-                std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
-            }
-            if (src2->src[1] != nullptr) {
-                std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
-            }
-            std::cerr << std::endl << "Result:" << std::endl;
-            ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
-            std::cerr << std::endl;
-            std::cerr << std::endl << "Result:" << std::endl;
-            ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
-            std::cerr << std::endl;
-            std::vector<const ggml_tensor *> done;
-            ggml_vk_print_graph_origin(src2_clone, done);
+            ggml_vk_print_tensor(src2, "src2");
         }
     }
 
@@ -6752,16 +7092,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_DIV) {
         tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
+    } else if (tensor->op == GGML_OP_CONCAT) {
+        tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_UPSCALE) {
+        tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_SCALE) {
         tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_CLAMP) {
         tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+    } else if (tensor->op == GGML_OP_PAD) {
+        tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
     } else if (tensor->op == GGML_OP_ADD) {
         tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_NORM) {
         tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_GROUP_NORM) {
+        tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
@@ -6777,12 +7125,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         const int mode        = ((int32_t *) tensor->op_params)[2];
         //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
         const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];
-        float freq_base       = ((float *)   tensor->op_params)[5];
-        float freq_scale      = ((float *)   tensor->op_params)[6];
-        float ext_factor      = ((float *)   tensor->op_params)[7];
-        float attn_factor     = ((float *)   tensor->op_params)[8];
-        float beta_fast       = ((float *)   tensor->op_params)[9];
-        float beta_slow       = ((float *)   tensor->op_params)[10];
+        const float freq_base       = ((float *) tensor->op_params)[5];
+        const float freq_scale      = ((float *) tensor->op_params)[6];
+        const float ext_factor      = ((float *) tensor->op_params)[7];
+        const float attn_factor     = ((float *) tensor->op_params)[8];
+        const float beta_fast       = ((float *) tensor->op_params)[9];
+        const float beta_slow       = ((float *) tensor->op_params)[10];
         tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
@@ -6792,9 +7140,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_UNARY_OP_GELU:
             tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
             break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
+            break;
         case GGML_UNARY_OP_RELU:
             tensor_clone = ggml_relu(ggml_ctx, src0_clone);
             break;
+        case GGML_UNARY_OP_TANH:
+            tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
@@ -6823,6 +7177,23 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
         tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_IM2COL) {
+        const int32_t s0 = tensor->op_params[0];
+        const int32_t s1 = tensor->op_params[1];
+        const int32_t p0 = tensor->op_params[2];
+        const int32_t p1 = tensor->op_params[3];
+        const int32_t d0 = tensor->op_params[4];
+        const int32_t d1 = tensor->op_params[5];
+
+        const bool is_2D = tensor->op_params[6] == 1;
+        tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
+    } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
+        const int32_t dim = tensor->op_params[0];
+        const int32_t max_period = tensor->op_params[1];
+        tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+    } else if (tensor->op == GGML_OP_LEAKY_RELU) {
+        const float * op_params = (const float *)tensor->op_params;
+        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
     } else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
@@ -6834,7 +7205,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
     ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
 
     if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-        ggml_vk_print_tensor(ctx, tensor_clone, "tensor_clone");
+        ggml_vk_print_tensor(tensor_clone, "tensor_clone");
     }
 
     comp_size = ggml_nbytes(tensor_clone);
@@ -6851,9 +7222,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
     }
 
     ggml_free(ggml_ctx);
+
+    VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
 
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
@@ -6977,11 +7350,6 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor *
         std::cerr << std::endl << "Correct:" << std::endl;
         ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
         std::cerr << std::endl;
-        std::cerr << std::endl << "Result:" << std::endl;
-        ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0);
-        std::cerr << std::endl << "Correct:" << std::endl;
-        ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 1, 0);
-        std::cerr << std::endl;
         std::vector<const ggml_tensor *> done;
         ggml_vk_print_graph_origin(tensor, done);
     }
@@ -7018,5 +7386,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor *
     if (ggml_backend_buffer_is_vk(tensor->buffer)) {
         free(tensor_data);
     }
+
+    VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
 }
 #endif
diff --git a/src/ggml_vk_generate_shaders.py b/src/ggml_vk_generate_shaders.py
deleted file mode 100644 (file)
index 7ebf600..0000000
+++ /dev/null
@@ -1,3149 +0,0 @@
-#!/usr/bin/env python
-
-import logging
-import argparse
-import asyncio
-import os
-import sys
-from tempfile import gettempdir, NamedTemporaryFile
-
-logger = logging.getLogger("ggml-vk-generate-shaders")
-
-shader_f32 = """
-#define FLOAT_TYPE float
-"""
-shader_f16 = """
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#define FLOAT_TYPE float16_t
-"""
-shader_int8_ext = """
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-"""
-
-# Type-specific defines
-shader_f32_defines = """
-#define QUANT_K 1
-#define QUANT_R 1
-
-#define A_TYPE float
-"""
-shader_f16_defines = """
-#define QUANT_K 1
-#define QUANT_R 1
-
-#define A_TYPE float16_t
-"""
-shader_q4_0_defines = """
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_0
-{
-    float16_t d;
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_0
-"""
-shader_q4_1_defines = """
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_1
-{
-    float16_t d;
-    float16_t m;
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_1
-"""
-shader_q5_0_defines = """
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_0
-{
-    float16_t d;
-    uint16_t qh[2];
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_0
-"""
-shader_q5_1_defines = """
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_1
-{
-    float16_t d;
-    float16_t m;
-    uint qh;
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_1
-"""
-shader_q8_0_defines = """
-#define QUANT_K 32
-#define QUANT_R 1
-
-struct block_q8_0
-{
-    float16_t d;
-    int8_t qs[32];
-};
-
-#define A_TYPE block_q8_0
-"""
-
-# K-quants
-shader_q2_K_defines = """
-#define QUANT_K 256
-
-struct block_q2_K
-{
-    uint8_t scales[QUANT_K/16];
-    uint8_t qs[QUANT_K/4];
-    f16vec2 d;
-};
-
-#define A_TYPE block_q2_K
-"""
-shader_q3_K_defines = """
-#define QUANT_K 256
-
-struct block_q3_K
-{
-    uint8_t hmask[QUANT_K/8];
-    uint8_t qs[QUANT_K/4];
-    uint8_t scales[12];
-    float16_t d;
-};
-
-#define A_TYPE block_q3_K
-"""
-shader_q4_K_defines = """
-#define QUANT_K 256
-
-struct block_q4_K
-{
-    f16vec2 d;
-    uint8_t scales[3*QUANT_K/64];
-    uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q4_K
-"""
-shader_q5_K_defines = """
-#define QUANT_K 256
-
-struct block_q5_K
-{
-    f16vec2 d;
-    uint8_t scales[12];
-    uint8_t qh[QUANT_K/8];
-    uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q5_K
-"""
-shader_q6_K_defines = """
-#define QUANT_K 256
-
-struct block_q6_K
-{
-    uint8_t ql[QUANT_K/2];
-    uint8_t qh[QUANT_K/4];
-    int8_t scales[QUANT_K/16];
-    float16_t d;
-};
-
-#define A_TYPE block_q6_K
-"""
-
-# Dequant functions
-shader_float_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
-}
-"""
-
-shader_q4_0_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-}
-"""
-
-shader_q4_1_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const float m = float(data_a[a_offset + ib].m);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2(vui & 0xF, vui >> 4) * d + m;
-}
-"""
-
-shader_q5_0_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
-    const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-}
-"""
-
-shader_q5_1_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const float m = float(data_a[a_offset + ib].m);
-    const uint uint_qh = data_a[a_offset + ib].qh;
-    const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-}
-"""
-
-shader_q8_0_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
-}
-"""
-
-# MULMAT
-
-mulmat_head = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#ifdef MUL_MAT_ID
-#extension GL_EXT_buffer_reference2 : require
-#extension GL_EXT_nonuniform_qualifier : require
-#extension GL_EXT_scalar_block_layout : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-
-#define EXPERT_COUNT 8
-#endif
-
-#ifndef LOAD_VEC_A
-#define LOAD_VEC_A 1
-#endif
-#ifndef LOAD_VEC_B
-#define LOAD_VEC_B 1
-#endif
-"""
-
-mulmat_body1 = """
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-layout (push_constant) uniform parameter
-{
-    uint M;
-    uint N;
-    uint K;
-    uint stride_a;
-    uint stride_b;
-    uint stride_d;
-    uint k_split;
-
-    uint ne02;
-    uint ne12;
-    uint broadcast2;
-    uint broadcast3;
-
-    uint batch_stride_a;
-    uint batch_stride_b;
-    uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
-    uint expert_stride_a;
-    uint expert_stride_b0;
-    uint expert_stride_b1;
-    uint expert_stride_d;
-
-    uint ids_stride;
-
-    uint n_as;
-    uint nei0;
-    uint nei1;
-    uint nbi1;
-    uint ne11;
-#endif
-} p;
-
-layout (constant_id = 1) const uint BM = 64;
-layout (constant_id = 2) const uint BN = 64;
-layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
-layout (constant_id = 4) const uint WM = 32;
-layout (constant_id = 5) const uint WN = 32;
-layout (constant_id = 6) const uint WMITER = 2;
-layout (constant_id = 7) const uint TM = 4;
-layout (constant_id = 8) const uint TN = 2;
-layout (constant_id = 9) const uint WARP = 32;
-
-shared FLOAT_TYPE buf_a[BM * (BK+1)];
-shared FLOAT_TYPE buf_b[BN * (BK+1)];
-
-#ifdef MUL_MAT_ID
-shared u8vec2 rowids[2048];
-#endif
-
-void main() {
-#ifdef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z / p.n_as;
-    const uint expert_idx = gl_GlobalInvocationID.z % p.n_as;
-#else
-    const uint batch_idx = gl_GlobalInvocationID.z;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-    const uint blocks_m = (p.M + BM - 1) / BM;
-    const uint ir = gl_WorkGroupID.x % blocks_m;
-    const uint ik = gl_WorkGroupID.x / blocks_m;
-    const uint ic = gl_WorkGroupID.y;
-
-    const uint warp_i = gl_LocalInvocationID.x / WARP;
-    const uint warp_r = warp_i % (BM / WM);
-    const uint warp_c = warp_i / (BM / WM);
-
-    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const uint WSUBM = WM / WMITER;
-    const uint WSUBN = WN / WNITER;
-
-    const uint tiw = gl_LocalInvocationID.x % WARP;
-    const uint tiwr = tiw % (WSUBM / TM);
-    const uint tiwc = tiw / (WSUBM / TM);
-
-    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
-    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
-    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
-    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
-
-    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
-    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
-
-#ifdef MUL_MAT_ID
-    uint _ne1 = 0;
-    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
-        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
-            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
-                rowids[_ne1] = u8vec2(ii0, ii1);
-                _ne1++;
-            }
-        }
-    }
-
-    const u8vec2 id = rowids[ir * BN + ic];
-#endif
-
-    const uint start_k = ik * p.k_split;
-    const uint end_k = min(p.K, (ik + 1) * p.k_split);
-
-    uint pos_a = (
-#ifdef MUL_MAT_ID
-        expert_idx * p.expert_stride_a +
-#endif
-        batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
-    uint pos_b = (
-#ifdef MUL_MAT_ID
-        id.y * p.expert_stride_b1 +
-        (id.x % p.ne11) * p.expert_stride_b0 +
-#endif
-        batch_idx * p.batch_stride_b +
-        ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
-
-    float sums[WMITER * TM * WNITER * TN];
-    FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
-
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {"""
-
-mulmat_load_scalar = """
-#if LOAD_VEC_A == 8
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
-            buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
-            buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
-            buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
-            buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
-#elif LOAD_VEC_A == 4
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
-#else
-            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-                buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
-            } else {
-                buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
-            }
-#endif
-"""
-
-mulmat_load_q4_0 = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q4_1 = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q5_0 = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q5_1 = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint uint_qh = data_a[ib].qh;
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q8_0 = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 16;
-            const uint iqs = (idx & 0xF) * 2;
-
-            const float d = float(data_a[ib].d);
-            const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
-
-
-mulmat_load_q2_K = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                         // 2 values per idx
-            const uint iqs = idx % 128;                        // 0..127
-
-            const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
-            const uint scalesi = iqs / 8;                      // 0..15
-            const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6
-
-            const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
-            const uint scales = data_a[ib].scales[scalesi];
-            const vec2 d = vec2(data_a[ib].d);
-
-            const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q3_K = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                   // 2 values per idx
-            const uint iqs = idx % 128;                  // 0..127
-
-            const uint n = iqs / 64;                     // 0,1
-            const uint qsi = n * 32 + (iqs % 16) * 2;    // 0,2,4..62
-            const uint hmi =          (iqs % 16) * 2;    // 0,2,4..30
-            const uint j = (iqs % 64) / 4;               // 0..3
-            const uint is = iqs / 8;                     // 0..15
-            const uint halfsplit = ((iqs % 64) / 16);    // 0,1,2,3
-            const uint qsshift = halfsplit * 2;          // 0,2,4,6
-            const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128
-
-            const int8_t us = int8_t(is <  4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
-                                    is <  8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
-                                    is < 12 ? (data_a[ib].scales[is-8] >>  4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
-                                            (data_a[ib].scales[is-8] >>  4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
-            const float dl = float(data_a[ib].d) * float(us - 32);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi    ] & m) != 0) ? 0 : 4)));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
-
-mulmat_load_q4_K = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
-
-            const uint n = iqs / 32;                   // 0,1,2,3
-            const uint b = (iqs % 32) / 16;            // 0,1
-            const uint is = 2 * n + b;                 // 0..7
-            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
-
-            const vec2 loadd = vec2(data_a[ib].d);
-
-            uint8_t sc;
-            uint8_t mbyte;
-            if (is < 4) {
-                sc    = uint8_t(data_a[ib].scales[is    ] & 63);
-                mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
-            } else {
-                sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
-                mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
-            }
-            const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);"""
-
-mulmat_load_q5_K = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
-
-            const uint n = iqs / 32;                   // 0,1,2,3
-            const uint b = (iqs % 32) / 16;            // 0,1
-            const uint is = 2 * n + b;                 // 0..7
-            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
-            const uint qhi = (iqs % 16) * 2;           // 0,2,4..30
-
-            const uint8_t hm = uint8_t(1 << (iqs / 16));
-
-            const vec2 loadd = vec2(data_a[ib].d);
-
-            uint8_t sc;
-            uint8_t mbyte;
-            if (is < 4) {
-                sc    = uint8_t(data_a[ib].scales[is    ] & 63);
-                mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
-            } else {
-                sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
-                mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
-            }
-            const float d = loadd.x * sc;
-            const float m = loadd.y * mbyte;
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0)) - m);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);"""
-
-mulmat_load_q6_K = """
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint iqs = idx % 128;                 // 0..127
-
-            const uint n = iqs / 64;                    // 0,1
-            const uint b = (iqs % 64) / 32;             // 0,1
-            const uint is_b = (iqs % 16) / 8;           // 0,1
-            const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-            const uint is = 8 * n + qhshift + is_b;     // 0..15
-            const uint qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126
-            const uint qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62
-
-            const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi    ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi    ] >> qhshift) & 3) << 4)) - 32));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));"""
-
-mulmat_body2 = """
-        }
-        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
-#if LOAD_VEC_B == 8
-            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-            const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
-            buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
-            buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
-            buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
-            buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
-#elif LOAD_VEC_B == 4
-            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-            const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
-#else
-            if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
-            } else {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
-            }
-#endif
-        }
-
-        barrier();
-
-        pos_a += BK / LOAD_VEC_A;
-        pos_b += BK / LOAD_VEC_B;
-
-        for (uint i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const uint dr = ir * BM + warp_r * WM;
-    const uint dc = ic * BN + warp_c * WN;
-
-    const uint offsets =
-#ifdef MUL_MAT_ID
-            expert_idx * p.expert_stride_d +
-#endif
-            batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-
-    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-
-            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
-                    }
-                }
-            }
-        }
-    }
-}
-"""
-
-mulmat_split_k_reduce_src = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ne;
-    uint k_num;
-} p;
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.x;
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    float result = 0.0f;
-
-    [[unroll]] for (uint i = 0; i < p.k_num; i++) {
-        result += data_a[i * p.ne + idx];
-    }
-
-    data_d[idx] = result;
-}
-"""
-
-# DEQUANT SHADER
-dequant_head = """#version 450
-
-#extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint M;
-    uint K;
-    uint stride_a;
-    uint stride_b;
-    uint nel;
-} p;
-"""
-
-dequant_f32_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.x * 16;
-
-    if (i >= p.nel) {
-        return;
-    }
-
-    [[unroll]] for (uint l = 0; l < 16; l++) {
-        data_b[i + l] = D_TYPE(data_a[i + l]);
-    }
-}
-"""
-
-dequant_q4_0_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const float dm = -8.0f * d;
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm);
-        data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >>  4) + dm);
-    }
-}
-"""
-
-dequant_q4_1_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const float m = float(data_a[ib].m);
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
-        data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >>  4) + m);
-    }
-}
-"""
-
-dequant_q5_0_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        const uint iqs = q_idx + l;
-        const uint vui = uint(data_a[ib].qs[iqs]);
-        data_b[b_idx + l +  0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
-        data_b[b_idx + l + 16] = D_TYPE(d * (((vui >>  4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
-    }
-}
-"""
-
-dequant_q5_1_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const float m = float(data_a[ib].m);
-    const uint qh = data_a[ib].qh;
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        const uint iqs = q_idx + l;
-        const uint vui = uint(data_a[ib].qs[iqs]);
-        data_b[b_idx + l +  0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
-        data_b[b_idx + l + 16] = D_TYPE(d * (((vui >>  4) | ((qh >> (iqs + 12)) & 0x10))) + m);
-    }
-}
-"""
-
-dequant_q8_0_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 16*il;
-
-    const float d = float(data_a[ib].d);
-
-    const uint q_idx = 16*il;
-
-    [[unroll]] for (uint l = 0; l < 16; l += 2) {
-        data_b[b_idx + l    ] = D_TYPE(d * data_a[ib].qs[q_idx + l    ]);
-        data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
-    }
-}
-"""
-
-# K-quants
-dequant_q2_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint tid = gl_LocalInvocationID.x;
-        const uint ip = tid / 32;
-        const uint il = tid - 32 * ip;
-        const uint is = 8 * ip + il / 16;
-
-        const uint y_idx = i * QUANT_K + 128 * ip + il;
-
-        const uint ql_idx = 32 * ip + il;
-        const uint8_t qs = data_a[i].qs[32 * ip + il];
-
-        FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-        data_b[y_idx +  0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
-        data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
-        data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
-        data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
-    }
-}
-"""
-dequant_q3_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint r = gl_LocalInvocationID.x / 4;
-        const uint tid = r / 2;
-        const uint is0 = r % 2;
-        const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
-        const uint n = tid / 4;
-        const uint j = tid - 4*n;
-
-        const uint8_t m = uint8_t(1 << (4*n + j));
-        const uint is = 8*n + 2*j + is0;
-        const uint shift = 2*j;
-
-        const int8_t us = int8_t(is <  4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
-                                 is <  8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
-                                 is < 12 ? (data_a[i].scales[is-8] >>  4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
-                                           (data_a[i].scales[is-8] >>  4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
-        const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
-        const FLOAT_TYPE dl    = d_all * FLOAT_TYPE(us - 32);
-
-        const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
-        const uint qs_idx = 32*n;
-
-        for (uint l = l0; l < l0 + 4; ++l) {
-            data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
-        }
-    }
-}
-"""
-dequant_q4_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint tid = gl_LocalInvocationID.x;
-        const uint il = tid / 8;
-        const uint ir = tid % 8;
-        const uint is = 2 * il;
-        const uint n = 4;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
-        const uint y_idx = i * QUANT_K + 64 * il + n * ir;
-        const uint qs_idx = 32*il + n * ir;
-
-        uint8_t sc;
-        uint8_t m;
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is] & 63);
-            m  = uint8_t(data_a[i].scales[is + 4] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 4] >>  4) | ((data_a[i].scales[is    ] >> 6) << 4));
-        }
-        const FLOAT_TYPE d1 = dall * sc;
-        const FLOAT_TYPE m1 = dmin * m;
-
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is + 1] & 63);
-            m  = uint8_t(data_a[i].scales[is + 5] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 5] >>  4) | ((data_a[i].scales[is + 1] >> 6) << 4));
-        }
-        const FLOAT_TYPE d2 = dall * sc;
-        const FLOAT_TYPE m2 = dmin * m;
-
-        [[unroll]] for (uint l = 0; l < n; ++l) {
-            data_b[y_idx + l     ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
-            data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >>  4) - m2);
-        }
-    }
-}
-"""
-dequant_q5_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint tid = gl_LocalInvocationID.x;
-        const uint il = tid / 16;
-        const uint ir = tid % 16;
-        const uint is = 2 * il;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
-        const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
-        const uint qs_idx = 32*il + 2 * ir;
-        const uint qh_idx = 2 * ir;
-
-        uint8_t sc;
-        uint8_t m;
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is] & 63);
-            m  = uint8_t(data_a[i].scales[is + 4] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 4] >>  4) | ((data_a[i].scales[is    ] >> 6) << 4));
-        }
-        const FLOAT_TYPE d1 = dall * sc;
-        const FLOAT_TYPE m1 = dmin * m;
-
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is + 1] & 63);
-            m  = uint8_t(data_a[i].scales[is + 5] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 5] >>  4) | ((data_a[i].scales[is + 1] >> 6) << 4));
-        }
-        const FLOAT_TYPE d2 = dall * sc;
-        const FLOAT_TYPE m2 = dmin * m;
-
-        const uint8_t hm1 = uint8_t(1 << (2 * il    ));
-        const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
-        data_b[y_idx     ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx    ] & 0xF) + (((data_a[i].qh[qh_idx    ] & hm1) != 0) ? 16 : 0)) - m1);
-        data_b[y_idx +  1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
-        data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx    ]  >> 4) + (((data_a[i].qh[qh_idx    ] & hm2) != 0) ? 16 : 0)) - m2);
-        data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1]  >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
-    }
-}
-"""
-dequant_q6_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-        const uint tid = gl_LocalInvocationID.x;
-        const uint ip = tid / 32;
-        const uint il = tid - 32 * ip;
-        const uint is = 8 * ip + il / 16;
-
-        const uint y_idx = i * QUANT_K + 128 * ip + il;
-
-        const uint ql_idx = 64 * ip + il;
-        const uint8_t qh = data_a[i].qh[32 * ip + il];
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
-
-        data_b[y_idx +  0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx +  0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
-        data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
-        data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx +  0] >>  4) | (((qh >> 4) & 3) << 4)) - 32)));
-        data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >>  4) | (((qh >> 6) & 3) << 4)) - 32)));
-    }
-}
-"""
-
-# Mul Mat Vec
-mul_mat_vec_head = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_8bit_storage : require
-
-#ifdef MUL_MAT_ID
-#define EXPERT_COUNT 8
-#endif
-"""
-
-
-mul_mat_vec_layout = """
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-layout (push_constant) uniform parameter
-{
-    uint ncols;
-    uint stride_a;
-    uint stride_b;
-    uint stride_d;
-
-    uint ne02;
-    uint ne12;
-    uint broadcast2;
-    uint broadcast3;
-
-    uint batch_stride_a;
-    uint batch_stride_b;
-    uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
-    uint expert_stride_a;
-    uint expert_stride_b0;
-    uint expert_stride_b1;
-    uint expert_stride_d0;
-    uint expert_stride_d1;
-
-    uint ne11;
-    uint nei0;
-    uint nbi1;
-    uint n_as;
-#endif
-} p;
-"""
-
-mul_mat_vec_body = """
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
-    const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
-    const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
-    const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
-    const uint a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.expert_stride_a +
-#endif
-            batch_idx_a * p.batch_stride_a;
-    const uint b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx0 % p.ne11) * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_b;
-    const uint d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx0 * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_d;
-
-    const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
-    tmp[tid] = FLOAT_TYPE(0.0f);
-
-    [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
-        const uint col = i*BLOCK_SIZE + 2*tid;
-        const uint ib = (row*p.ncols + col)/QUANT_K; // block index
-        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
-        const uint iybs = col - col%QUANT_K; // y block start index
-
-        vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
-
-        // matrix multiplication
-        tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
-                    FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
-"""
-
-# K-quants
-mul_mat_vec_q2_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
-    const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
-    const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
-    const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
-    const uint a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.expert_stride_a +
-#endif
-            batch_idx_a * p.batch_stride_a;
-    const uint b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx0 % p.ne11) * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_b;
-    const uint d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx0 * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_d;
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
-
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
-
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;            // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint s_offset = 8*v_im;
-    const uint y_offset = 128*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
-        FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-        FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
-            sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
-                  + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
-        }
-        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
-"""
-mul_mat_vec_q3_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
-    const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
-    const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
-    const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
-    const uint a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.expert_stride_a +
-#endif
-            batch_idx_a * p.batch_stride_a;
-    const uint b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx0 % p.ne11) * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_b;
-    const uint d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx0 * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_d;
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
-
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
-
-    const uint8_t m = uint8_t(1 << (4 * v_im));
-
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;            // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint y_offset = 128*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    const uint s_shift = 4 * v_im;
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
-        }
-        tmp[16 * ix + tid] += d * sum;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
-"""
-mul_mat_vec_q4_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
-    const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
-    const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
-    const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
-    const uint a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.expert_stride_a +
-#endif
-            batch_idx_a * p.batch_stride_a;
-    const uint b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx0 % p.ne11) * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_b;
-    const uint d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx0 * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_d;
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 8/K_QUANTS_PER_ITERATION;             // 8 or 4
-
-    const uint il = tid/step;                               // 0...3
-    const uint ir = tid - step*il;                          // 0...7 or 0...3
-    const uint n =  2 * K_QUANTS_PER_ITERATION;             // 2 or 4
-
-    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const uint v_in = il % 2;
-
-    const uint l0 = n * (2 * ir + v_in);            // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint y_offset = 64*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y1_idx = i * QUANT_K + y_offset;
-        const uint y2_idx = y1_idx + 128;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
-        const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
-        const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
-        const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
-        const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
-        const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
-        const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
-        const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
-        const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-#if K_QUANTS_PER_ITERATION == 2
-        const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset +  2] & 0xf);
-        const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset +  3] & 0xf);
-        const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
-        const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
-        const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset +  2]  >> 4);
-        const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset +  3]  >> 4);
-        const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
-        const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
-        const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
-        const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
-        const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66]  >> 4);
-        const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67]  >> 4);
-
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx    ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx    ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
-#else
-        const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
-        const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
-        const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
-        const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
-
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * q4_0  + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2  + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * q4_4  + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-        );
-
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
-#endif
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
-"""
-mul_mat_vec_q5_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
-    const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
-    const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
-    const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
-    const uint a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.expert_stride_a +
-#endif
-            batch_idx_a * p.batch_stride_a;
-    const uint b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx0 % p.ne11) * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_b;
-    const uint d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx0 * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_d;
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/2;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%2;  // 0 or 0, 1
-
-    const uint il = tid/4;                           // 0...3
-    const uint ir = tid - 4*il;                      // 0...7 or 0...3
-
-    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const uint v_in = il % 2;
-
-    const uint l0 = 4*ir + 2*v_in;                   // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint y_offset = 64*v_im + l0;
-
-    const uint8_t hm1 = uint8_t(1 << (2*v_im));
-    const uint8_t hm2 = uint8_t(hm1 << 4);
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
-        const uint y1_idx = i * QUANT_K + y_offset;
-        const uint y2_idx = y1_idx + 128;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
-        const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
-        const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
-        const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
-        const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
-        const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
-        const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
-        const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
-        const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-        const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
-        const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
-        const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
-        const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
-        const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset + 16]  >> 4);
-        const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset + 17]  >> 4);
-        const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
-        const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
-        const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
-        const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
-        const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80]  >> 4);
-        const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81]  >> 4);
-
-        const FLOAT_TYPE sx = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sy = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sz = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE sw = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
-          + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
-        );
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
-          + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
-        );
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
-"""
-mul_mat_vec_q6_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
-    const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
-    const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
-    const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
-    const uint a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.expert_stride_a +
-#endif
-            batch_idx_a * p.batch_stride_a;
-    const uint b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx0 % p.ne11) * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_b;
-    const uint d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx0 * p.expert_stride_b0 +
-            expert_idx1 * p.expert_stride_b1 +
-#endif
-            batch_idx * p.batch_stride_d;
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
-
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
-
-#if K_QUANTS_PER_ITERATION == 1
-    const uint l0 = v_in;                                   // 0...15
-    const uint is = 0;
-#else
-    const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28
-    const uint is = v_in / 4;
-#endif
-
-    const uint ql_offset = 64*v_im + l0;
-    const uint qh_offset = 32*v_im + l0;
-    const uint s_offset  =  8*v_im + is;
-    const uint y_offset = 128*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y_idx   = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-#if K_QUANTS_PER_ITERATION == 1
-        FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
-        tmp[16 * ix + tid] += sum;
-#else
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
-                 + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
-        }
-        tmp[16 * ix + tid] += sum;
-#endif
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-       }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
-"""
-
-mul_mat_p021_src = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
-    uint ncols_x;
-    uint nrows_x;
-    uint nchannels_x;
-    uint nchannels_y;
-    uint b_offset;
-    uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint tid = gl_LocalInvocationID.x;
-    const uint row_x = gl_GlobalInvocationID.y;
-    const uint channel = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
-
-    const uint nrows_y = p.ncols_x;
-    const uint nrows_dst = p.nrows_x;
-    const uint row_dst = row_x;
-
-    tmp[tid] = FLOAT_TYPE(0.0f);
-
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
-
-        if (col_x >= p.ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
-        const uint row_y = col_x;
-
-        // y is not transposed but permuted
-        const uint iy = channel*nrows_y + row_y;
-
-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
-    }
-
-    // dst is not transposed and not permuted
-    const uint idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-
-    if (tid == 0) {
-        dst[idst] = tmp[0];
-    }
-}
-"""
-
-
-mul_mat_nc_src = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
-    uint ncols_x;
-    uint nrows_x;
-    uint row_stride_x;
-    uint channel_stride_x;
-    uint channel_x_divisor;
-    uint b_offset;
-    uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint tid       = gl_LocalInvocationID.x;
-    const uint row_x     = gl_GlobalInvocationID.y;
-    const uint channel   = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / p.channel_x_divisor;
-
-    const uint nrows_y   = p.ncols_x;
-    const uint nrows_dst = p.nrows_x;
-    const uint row_dst   = row_x;
-
-    const uint idst = channel*nrows_dst + row_dst;
-
-    tmp[tid] = 0.0f;
-
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
-
-        if (col_x >= p.ncols_x) {
-            break;
-        }
-
-        const uint row_y = col_x;
-
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel*nrows_y + row_y;
-
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
-        tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-
-    if (tid == 0) {
-        dst[idst] = tmp[0];
-    }
-}
-"""
-
-generic_head = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint KX;
-    uint KY;
-    float param1;
-    float param2;
-} p;
-"""
-
-generic_unary_op_head = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint ne;
-    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
-    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
-    uint d_offset;
-    float param1; float param2;
-} p;"""
-
-generic_unary_op_layout = """
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};"""
-
-generic_unary_op_funcs = """
-uint src0_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint dst_idx(uint idx) {
-    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
-    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
-    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
-    const uint i12_offset = i12*p.ne11*p.ne10;
-    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
-    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
-    return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
-}"""
-
-generic_unary_op_main = """
-void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
-        return;
-    }
-"""
-
-generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_layout}\n{generic_unary_op_funcs}\n{generic_unary_op_main}"
-
-generic_binary_op_head = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint ne;
-    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
-    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
-    uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
-    uint d_offset;
-    float param1; float param2;
-} p;"""
-
-generic_binary_op_layout = """
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};"""
-
-generic_binary_op_funcs = """
-uint src0_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint src1_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-
-    return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
-}
-
-uint dst_idx(uint idx) {
-    const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
-    const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
-    const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
-    const uint i22_offset = i22*p.ne21*p.ne20;
-    const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
-    const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
-    return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
-}"""
-
-generic_binary_op_main = """
-void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
-        return;
-    }
-"""
-
-generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
-
-# MUL F32
-mul_body = """
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
-}
-"""
-
-# ADD
-add_body = """
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
-}
-"""
-
-# SCALE
-scale_body = """
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
-}
-"""
-
-# SQR
-sqr_body = """
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
-}
-"""
-
-# CLAMP
-clamp_body = """
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
-}
-"""
-
-# CPY
-cpy_end = """
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
-}
-"""
-# Causes an optimization error otherwise
-cpy_f16_f16_end = """
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
-}
-"""
-
-# GET_ROWS
-get_rows_float_body = """
-void main() {
-    const uint i00 = gl_GlobalInvocationID.x;
-    const uint i10 = gl_GlobalInvocationID.y;
-    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
-    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
-    if (i00 >= p.ne00) {
-        return;
-    }
-
-    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
-    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
-    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
-#else
-    data_d[d_offset + i00] = data_a[a_offset + i00];
-#endif
-}
-"""
-
-get_rows_body = """
-void main() {
-    const uint i00 = (gl_GlobalInvocationID.x)*2;
-    const uint i10 = gl_GlobalInvocationID.y;
-    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
-    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
-    if (i00 >= p.ne00) {
-        return;
-    }
-
-    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
-    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
-    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
-    const uint ib = a_offset + i00/QUANT_K; // block index
-    const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
-    const uint iybs = i00 - i00%QUANT_K; // dst block start index
-    const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
-    vec2 v = dequantize(ib, iqs, 0);
-
-    data_d[d_offset + iybs + iqs           ] = D_TYPE(v.x);
-    data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
-}
-"""
-
-# UNARY
-gelu_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const float GELU_COEF_A    = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    const uint i = gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    const float xi = float(data_a[i]);
-    const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
-    data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
-}
-"""
-
-silu_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    const float xi = float(data_a[i]);
-    data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));
-}
-"""
-
-relu_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    data_d[i] = max(float(data_a[i]), 0);
-}
-"""
-
-# DIAG_MASK_INF
-diag_mask_inf_head = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint ncols;
-    uint rows_per_channel;
-    uint n_past;
-} p;
-"""
-diag_mask_inf_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint col = gl_GlobalInvocationID.y;
-    const uint row = gl_GlobalInvocationID.x;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    const uint i = row*p.ncols + col;
-    if (col > p.n_past + row % p.rows_per_channel) {
-        data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
-    } else {
-        data_d[i] = D_TYPE(data_a[i]);
-    }
-}
-"""
-
-# NORMS
-norm_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared vec2 sum[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
-
-    sum[tid] = vec2(0.0f, 0.0f);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const float xi = float(data_a[row*p.KX + col]);
-        sum[tid].x += xi;
-        sum[tid].y += xi * xi;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            sum[tid] += sum[tid + s];
-        }
-        barrier();
-    }
-
-    const float mean = sum[0].x / p.KX;
-    const float var = sum[0].y / p.KX - mean * mean;
-    const float inv_std = inversesqrt(var + p.param1);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
-    }
-}
-"""
-
-rms_norm_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE sum[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
-
-    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
-        sum[tid] += xi * xi;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            sum[tid] += sum[tid + s];
-        }
-        barrier();
-    }
-
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
-    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
-    }
-}
-"""
-
-# SOFT_MAX
-soft_max_head = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint KX;
-    uint KY;
-    uint KZ;
-    float scale;
-    float max_bias;
-    float m0;
-    float m1;
-    uint n_head_log2;
-} p;
-"""
-
-soft_max_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) readonly buffer Z {C_TYPE data_c[];};
-layout (binding = 3) buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE vals[BLOCK_SIZE];
-
-void main() {
-    const uint tid = gl_LocalInvocationID.x;
-    const uint rowx = gl_WorkGroupID.x;
-    const uint rowy = rowx % p.KY;
-
-    float slope = 0.0f;
-
-    // ALiBi
-    if (p.max_bias > 0.0f) {
-        const uint h = rowx/p.KY; // head index
-
-        const float base = h < p.n_head_log2 ? p.m0 : p.m1;
-        const uint   exp  = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-    // Find max
-    vals[tid] = uintBitsToFloat(0xFF800000);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f));
-    }
-
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            vals[tid] = max(vals[tid], vals[tid + s]);
-        }
-        barrier();
-    }
-
-    const FLOAT_TYPE max_val = vals[0];
-    barrier();
-
-    // Sum up values
-    vals[tid] = FLOAT_TYPE(0.0f);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const uint i = rowx * p.KX + col;
-        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
-        vals[tid] += val;
-        data_d[i] = D_TYPE(val);
-    }
-
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            vals[tid] += vals[tid + s];
-        }
-        barrier();
-    }
-
-    const D_TYPE divisor = D_TYPE(vals[0]);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[rowx*p.KX + col] /= divisor;
-    }
-}
-"""
-
-# ROPE
-rope_src = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ncols;
-    float freq_scale;
-    uint p_delta_rows;
-    float freq_base;
-    float ext_factor;
-    float attn_factor;
-    float corr_dims[4];
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
-    float mscale = p.attn_factor;
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = p.freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (p.ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
-    }
-    cos_theta = cos(theta) * mscale;
-    sin_theta = sin(theta) * mscale;
-}
-
-void main() {
-    const uint col = gl_GlobalInvocationID.y * 2;
-    const uint row = gl_GlobalInvocationID.x;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    const uint i = row*p.ncols + col;
-    const uint i2 = row/p.p_delta_rows;
-
-    const int pos = data_b[i2];
-    const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols);
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, col, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[i + 0]);
-    const float x1 = float(data_a[i + 1]);
-
-    data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
-"""
-
-rope_neox_src = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ncols;
-    uint ndims;
-    float freq_scale;
-    uint p_delta_rows;
-    float freq_base;
-    float ext_factor;
-    float attn_factor;
-    float corr_dims[4];
-    float theta_scale;
-    float inv_ndims;
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
-    float mscale = p.attn_factor;
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = p.freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (p.ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
-    }
-    cos_theta = cos(theta) * mscale;
-    sin_theta = sin(theta) * mscale;
-}
-
-void main() {
-    const uint col = gl_GlobalInvocationID.y * 2;
-    const uint row = gl_GlobalInvocationID.x;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    const uint ib = col / p.ndims;
-    const uint ic = col % p.ndims;
-
-    if (ib > 0) {
-        const uint i = row*p.ncols + ib*p.ndims + ic;
-
-        data_d[i + 0] = data_a[i + 0];
-        data_d[i + 1] = data_a[i + 1];
-
-        return;
-    }
-
-    const uint i  = row*p.ncols + ib*p.ndims + ic/2;
-    const uint i2 = row/p.p_delta_rows;
-
-    const float cur_rot = p.inv_ndims * ic - ib;
-
-    const int pos = data_b[i2];
-    const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f);
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
-
-    const float x0 = float(data_a[i + 0]);
-    const float x1 = float(data_a[i + p.ndims/2]);
-
-    data_d[i + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
-"""
-
-argsort_src = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1)          buffer D {int data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ncols;
-    bool ascending;
-} p;
-
-void swap(uint idx0, uint idx1) {
-    int tmp = data_d[idx0];
-    data_d[idx0] = data_d[idx1];
-    data_d[idx1] = tmp;
-}
-
-void main() {
-    // bitonic sort
-    const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    const uint a_idx = row * p.ncols;
-    const uint d_idx = row * p.ncols;
-
-    // initialize indices
-    if (col < p.ncols) {
-        data_d[col] = col;
-    }
-    barrier();
-
-    for (uint k = 2; k <= p.ncols; k *= 2) {
-        for (uint j = k / 2; j > 0; j /= 2) {
-            const uint ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
-                        swap(d_idx + col, d_idx + ixj);
-                    }
-                } else {
-                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
-                        swap(d_idx + col, d_idx + ixj);
-                    }
-                }
-            }
-            barrier();
-        }
-    }
-}
-"""
-
-GLSLC = "glslc"
-
-VK_NUM_TYPES = 16
-
-GGML_TYPE_F32  = 0
-GGML_TYPE_F16  = 1
-GGML_TYPE_Q4_0 = 2
-GGML_TYPE_Q4_1 = 3
-GGML_TYPE_Q5_0 = 6
-GGML_TYPE_Q5_1 = 7
-GGML_TYPE_Q8_0 = 8
-GGML_TYPE_Q8_1 = 9
-GGML_TYPE_Q2_K = 10
-GGML_TYPE_Q3_K = 11
-GGML_TYPE_Q4_K = 12
-GGML_TYPE_Q5_K = 13
-GGML_TYPE_Q6_K = 14
-GGML_TYPE_Q8_K = 15
-
-
-type_names = {
-    GGML_TYPE_F32: "f32",
-    GGML_TYPE_F16: "f16",
-    GGML_TYPE_Q4_0: "q4_0",
-    GGML_TYPE_Q4_1: "q4_1",
-    GGML_TYPE_Q5_0: "q5_0",
-    GGML_TYPE_Q5_1: "q5_1",
-    GGML_TYPE_Q8_0: "q8_0",
-    GGML_TYPE_Q8_1: "q8_1",
-    GGML_TYPE_Q2_K: "q2_K",
-    GGML_TYPE_Q3_K: "q3_K",
-    GGML_TYPE_Q4_K: "q4_K",
-    GGML_TYPE_Q5_K: "q5_K",
-    GGML_TYPE_Q6_K: "q6_K",
-    GGML_TYPE_Q8_K: "q8_K",
-}
-
-K_QUANTS_PER_ITERATION = 2
-
-ASYNCIO_CONCURRENCY = 64
-
-output_dir = gettempdir()
-
-lock = asyncio.Lock()
-shader_fnames = []
-
-
-async def string_to_spv(name, code, defines, fp16=True):
-    f = NamedTemporaryFile(mode="w", delete=False)
-    f.write(code)
-    f.flush()
-
-    name = f"{name}{'_fp32' if not fp16 else ''}"
-    fname = os.path.join(output_dir, f"{name}.comp")
-
-    cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", f.name, "-o", fname]
-
-    cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
-
-    proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
-
-    stdout, stderr = await proc.communicate()
-
-    stdout = stdout.decode()
-    error = stderr.decode()
-
-    if proc.returncode:
-        # Generate preprocessed code
-        cmd = [GLSLC, "-E", f.name]
-        cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
-
-        proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
-
-        stdout, stderr = await proc.communicate()
-
-        logger.info(" ".join(cmd))
-
-        if proc.returncode:
-            raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}")
-
-        preprocessed_code = stdout.decode()
-
-        cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
-        code_with_lines = "\n".join([f"{i + 1}: {line}" for i, line in enumerate(preprocessed_code.splitlines())])
-        logger.error(f"cannot compile {name}\n\n{code_with_lines}\n\n{error}")
-        f.close()
-        os.remove(f.name)
-        sys.exit(proc.returncode)
-
-    f.close()
-    os.remove(f.name)
-
-    async with lock:
-        shader_fnames.append((name, fname))
-
-
-async def main():
-    logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")
-
-    tasks = []
-
-    stream = []
-
-    for fp16 in (False, True):
-        # mulmat
-        if fp16:
-            shader_float_type = shader_f16
-            load_vec = "8"
-            vec_type_f16 = "f16mat2x4"
-            vec_type = "mat2x4"
-        else:
-            shader_float_type = shader_f32
-            load_vec = "4"
-            vec_type_f16 = "f16vec4"
-            vec_type = "vec4"
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
-        tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
-
-        tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q4_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q4_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q4_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q5_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q5_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q5_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q5_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q8_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q8_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q2_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q2_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q3_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q3_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q4_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q4_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q5_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q5_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        stream.clear()
-        stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q6_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        tasks.append(string_to_spv("matmul_q6_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # MUL_MAT_ID
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # tasks.append(string_to_spv("matmul_id_f16", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_f16_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
-
-        # tasks.append(string_to_spv("matmul_id_f16_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_f16_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q4_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q4_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q4_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q4_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q5_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q5_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q5_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q5_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q8_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q8_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q2_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q2_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q3_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q3_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q4_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q4_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q5_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q5_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-        # stream.clear()
-        # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2))
-        # tasks.append(string_to_spv("matmul_id_q6_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
-        # tasks.append(string_to_spv("matmul_id_q6_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
-    # Shaders where precision is needed, so no fp16 version
-
-    # mul mat vec
-    for i in range(0, VK_NUM_TYPES):
-        stream.clear()
-        stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
-
-        if i == GGML_TYPE_F16:
-            stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body))
-        elif i == GGML_TYPE_Q4_0:
-            stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body))
-        elif i == GGML_TYPE_Q4_1:
-            stream.extend((shader_q4_1_defines, mul_mat_vec_layout, shader_q4_1_dequant_func, mul_mat_vec_body))
-        elif i == GGML_TYPE_Q5_0:
-            stream.extend((shader_q5_0_defines, mul_mat_vec_layout, shader_q5_0_dequant_func, mul_mat_vec_body))
-        elif i == GGML_TYPE_Q5_1:
-            stream.extend((shader_q5_1_defines, mul_mat_vec_layout, shader_q5_1_dequant_func, mul_mat_vec_body))
-        elif i == GGML_TYPE_Q8_0:
-            stream.extend((shader_q8_0_defines, mul_mat_vec_layout, shader_q8_0_dequant_func, mul_mat_vec_body))
-        elif i == GGML_TYPE_Q2_K:
-            stream.extend((shader_q2_K_defines, mul_mat_vec_layout, mul_mat_vec_q2_K_body))
-        elif i == GGML_TYPE_Q3_K:
-            stream.extend((shader_q3_K_defines, mul_mat_vec_layout, mul_mat_vec_q3_K_body))
-        elif i == GGML_TYPE_Q4_K:
-            stream.extend((shader_q4_K_defines, mul_mat_vec_layout, mul_mat_vec_q4_K_body))
-        elif i == GGML_TYPE_Q5_K:
-            stream.extend((shader_q5_K_defines, mul_mat_vec_layout, mul_mat_vec_q5_K_body))
-        elif i == GGML_TYPE_Q6_K:
-            stream.extend((shader_q6_K_defines, mul_mat_vec_layout, mul_mat_vec_q6_K_body))
-        else:
-            continue
-
-        tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
-        tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f16_f32", "".join(stream), {"B_TYPE": "float16_t", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
-
-        # tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
-
-    # Dequant shaders
-    for i in range(0, VK_NUM_TYPES):
-        stream.clear()
-
-        stream.extend((dequant_head, shader_int8_ext, shader_f32))
-
-        if i == GGML_TYPE_F32:
-            stream.append(dequant_f32_body)
-        elif i == GGML_TYPE_Q4_0:
-            stream.extend((shader_q4_0_defines, dequant_q4_0_body))
-        elif i == GGML_TYPE_Q4_1:
-            stream.extend((shader_q4_1_defines, dequant_q4_1_body))
-        elif i == GGML_TYPE_Q5_0:
-            stream.extend((shader_q5_0_defines, dequant_q5_0_body))
-        elif i == GGML_TYPE_Q5_1:
-            stream.extend((shader_q5_1_defines, dequant_q5_1_body))
-        elif i == GGML_TYPE_Q8_0:
-            stream.extend((shader_q8_0_defines, dequant_q8_0_body))
-        elif i == GGML_TYPE_Q2_K:
-            stream.extend((shader_q2_K_defines, dequant_q2_K_body))
-        elif i == GGML_TYPE_Q3_K:
-            stream.extend((shader_q3_K_defines, dequant_q3_K_body))
-        elif i == GGML_TYPE_Q4_K:
-            stream.extend((shader_q4_K_defines, dequant_q4_K_body))
-        elif i == GGML_TYPE_Q5_K:
-            stream.extend((shader_q5_K_defines, dequant_q5_K_body))
-        elif i == GGML_TYPE_Q6_K:
-            stream.extend((shader_q6_K_defines, dequant_q6_K_body))
-        else:
-            continue
-
-        tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
-
-    # get_rows
-    for i in range(0, VK_NUM_TYPES):
-        stream.clear()
-        stream.extend((generic_binary_op_head, shader_int8_ext, shader_f32))
-        optimization_workaround = False
-
-        if i == GGML_TYPE_F32:
-            stream.extend((shader_f32_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
-        elif i == GGML_TYPE_F16:
-            stream.extend((shader_f16_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
-            optimization_workaround = True
-        elif i == GGML_TYPE_Q4_0:
-            stream.extend((shader_q4_0_defines, generic_binary_op_layout, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body))
-        elif i == GGML_TYPE_Q4_1:
-            stream.extend((shader_q4_1_defines, generic_binary_op_layout, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body))
-        elif i == GGML_TYPE_Q5_0:
-            stream.extend((shader_q5_0_defines, generic_binary_op_layout, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body))
-        elif i == GGML_TYPE_Q5_1:
-            stream.extend((shader_q5_1_defines, generic_binary_op_layout, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body))
-        elif i == GGML_TYPE_Q8_0:
-            stream.extend((shader_q8_0_defines, generic_binary_op_layout, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body))
-        else:
-            continue
-
-        if optimization_workaround:
-            tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
-        else:
-            tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t"}))
-        tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float"}))
-
-    tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
-
-    # Norms
-    tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-
-    tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
-    tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_combined}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
-
-    tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
-    tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
-    tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
-    tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
-    tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
-    tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_combined}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
-    tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-
-    tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-
-    tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"}))
-
-    tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
-
-    tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
-    tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
-
-    tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
-
-    # Helper to decorate tasks with semaphore acquisition.
-    async def withSemaphore(sem, task):
-        async with sem:
-            return await task
-
-    # Run tasks concurrently guarded by a concurrency limit.
-    sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
-    await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
-
-    with open("ggml-vulkan-shaders.hpp", "w") as f:
-        f.write("#include <cstdint>\n\n")
-        for name, path in sorted(shader_fnames):
-
-            with open(path, "rb") as spv:
-                counter = 0
-                newline_counter = 0
-                f.write(f"unsigned char {name}_data[] = {{\n")
-                for val in spv.read():
-                    f.write(f"0x{val:02x},")
-                    newline_counter += 1
-                    counter += 1
-                    if newline_counter >= 12:
-                        newline_counter = 0
-                        f.write("\n")
-            f.write("\n};\n")
-            f.write(f"const uint64_t {name}_len = {counter};\n\n")
-            os.remove(path)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
-
-    parser.add_argument("--glslc", help="Path to glslc")
-    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-
-    args = parser.parse_args()
-
-    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
-
-    if args.glslc:
-        GLSLC = args.glslc
-
-    asyncio.run(main())
\ No newline at end of file
index 8475b01196386d861caa4ccb373a2728f4265000..3974845d637ab579e4e0a47e061192adf31e5fbe 100644 (file)
@@ -4,9 +4,11 @@
 #include "generic_binary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
 }
index ca272e227fd906ae64ff23237b78d5664934f218..7071302a4b65828088a23acbb6126732b0c227de 100644 (file)
@@ -4,10 +4,12 @@
 #include "generic_unary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
 }
diff --git a/src/vulkan-shaders/concat.comp b/src/vulkan-shaders/concat.comp
new file mode 100644 (file)
index 0000000..08ab551
--- /dev/null
@@ -0,0 +1,35 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+    const int dim = p.param3;
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
+    const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
+    const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
+    const uint i2_offset = i2*p.ne21*p.ne20;
+    const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
+    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
+
+    uint o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
+
+    const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+    const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
+    const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
+
+    const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+    data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
+#else
+    data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+#endif
+}
index efb55876e35c18dd6608370dff507b6879c205b2..c26917c0f9af5438fc48718413deed5a9be8d264 100644 (file)
@@ -4,13 +4,15 @@
 #include "generic_unary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
 #else
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
+    data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
 #endif
 }
index 8ee4bfc73886518871f567aa450e10d95ddb7406..8cfce58b150165039917cb474fddb5416efcceff 100644 (file)
@@ -4,9 +4,11 @@
 #include "generic_binary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
 }
index 9fe807cce9506fb3f25c3019892ea77d276e8ef1..4cc7a68ca18c51ec0c9970a5c0eb0fde113ca788 100644 (file)
@@ -13,7 +13,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 void main() {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    const uint i = gl_GlobalInvocationID.x;
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
 
     if (i >= p.KX) {
         return;
diff --git a/src/vulkan-shaders/gelu_quick.comp b/src/vulkan-shaders/gelu_quick.comp
new file mode 100644 (file)
index 0000000..e6e6fcf
--- /dev/null
@@ -0,0 +1,23 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const float GELU_QUICK_COEF = -1.702f;
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float x = float(data_a[i]);
+    data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
+}
index ab45d2564aa346415250c5aca4dcb73e4a5c0da1..b6beaff1cf65adf0b9ddef08fe3fe075aca88c41 100644 (file)
@@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
     uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
     uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
     uint d_offset;
-    float param1; float param2;
+    float param1; float param2; int param3;
 } p;
 
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
@@ -16,6 +16,10 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
+uint get_idx() {
+    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
 uint src0_idx(uint idx) {
     const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
     const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
index de08de7cd84fa962eb646ab48870c1834bd01d80..eacdefc7d8aa765e4c67b8fe6551b9eddd123b6f 100644 (file)
@@ -14,6 +14,10 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
+uint get_idx() {
+    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
 uint src0_idx(uint idx) {
     const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
     const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
diff --git a/src/vulkan-shaders/group_norm.comp b/src/vulkan-shaders/group_norm.comp
new file mode 100644 (file)
index 0000000..5ad9b28
--- /dev/null
@@ -0,0 +1,66 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared float tmp[BLOCK_SIZE];
+
+void main() {
+    const uint group_size = p.KX;
+    const float eps = p.param1;
+
+    const uint tid = gl_LocalInvocationID.x;
+    const uint start = gl_WorkGroupID.x * group_size + tid;
+    const uint end = start + group_size;
+
+    tmp[tid] = 0.0f;
+
+    // Calculate mean
+    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+        tmp[tid] += float(data_a[col]);
+    }
+
+    // tmp up partial tmps and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier();
+    }
+
+    const float mean = tmp[0] / group_size;
+    barrier();
+    tmp[tid] = 0.0f;
+
+    // Calculate variance
+    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+        const float xi = float(data_a[col]) - mean;
+        data_d[col] = D_TYPE(xi);
+        tmp[tid] += xi * xi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier();
+    }
+
+    const float variance = tmp[0] / group_size;
+    const float scale = inversesqrt(variance + eps);
+
+    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+        data_d[col] *= D_TYPE(scale);
+    }
+}
diff --git a/src/vulkan-shaders/im2col.comp b/src/vulkan-shaders/im2col.comp
new file mode 100644 (file)
index 0000000..4d48610
--- /dev/null
@@ -0,0 +1,57 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+    uint batch_offset; uint offset_delta;
+    uint IC;
+    uint IW; uint IH;
+    uint OW; uint OH;
+    uint KW; uint KH;
+    uint pelements;
+    uint CHW;
+    int s0; int s1;
+    int p0; int p1;
+    int d0; int d1;
+} p;
+
+#include "types.comp"
+
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x;
+    if (i >= p.pelements) {
+        return;
+    }
+
+    const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+    const uint kx = i / ksize;
+    const uint kd = kx * ksize;
+    const uint ky = (i - kd) / p.OW;
+    const uint ix = i % p.OW;
+
+    const uint oh = gl_GlobalInvocationID.y;
+    const uint batch = gl_GlobalInvocationID.z / p.IC;
+    const uint ic = gl_GlobalInvocationID.z % p.IC;
+
+    const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
+    const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
+
+    const uint offset_dst =
+        ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
+        (ic * (p.KW * p.KH) + ky * p.KW + kx);
+
+    if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
+        data_d[offset_dst] = D_TYPE(0.0f);
+    } else {
+        const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
+        data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
+    }
+}
diff --git a/src/vulkan-shaders/leaky_relu.comp b/src/vulkan-shaders/leaky_relu.comp
new file mode 100644 (file)
index 0000000..d90a99a
--- /dev/null
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float val = float(data_a[i]);
+    data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
+}
index bbb0aa1d26c1bf1915cea2fc4bfbb7068b341abd..bfb61c92d688e9d7d1e58538832ffa302e17610d 100644 (file)
@@ -4,9 +4,11 @@
 #include "generic_binary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
 }
index 803dbdcb3a936bbde9d67466794a795b61f743eb..6627a50bd949a07fb621c9147c845425d10f3408 100644 (file)
@@ -14,7 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 shared vec2 sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.x;
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
 
     sum[tid] = vec2(0.0f, 0.0f);
diff --git a/src/vulkan-shaders/pad.comp b/src/vulkan-shaders/pad.comp
new file mode 100644 (file)
index 0000000..a465cd5
--- /dev/null
@@ -0,0 +1,26 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i3 = idx / (p.ne12*p.ne11*p.ne10);
+    const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
+    const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10);
+    const uint i2_offset = i2*p.ne11*p.ne10;
+    const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
+    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
+
+    const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+    const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
+
+    const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+
+    data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
+}
index 7e5baa5b8b5e5d1c31066529847ae9fb4b76b2ee..52a19b62a67db71e58a5eb47a4b4fbb3705c78c7 100644 (file)
@@ -11,7 +11,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 void main() {
-    const uint i = gl_GlobalInvocationID.x;
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
 
     if (i >= p.KX) {
         return;
index cfd08d345cc83880d9322c50c844b09898262d14..b554400ba393f92227cf1ccd7c40549890e30a32 100644 (file)
@@ -14,7 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.x;
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
index 510cb7237e8a73f8f13f0e2f14de9e30cd183a3e..5cd2f668d01f3efe802c454ce225f046a31b0b7a 100644 (file)
@@ -4,9 +4,11 @@
 #include "generic_unary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1));
 }
index 15920f06e4722802459de4e19f0ff5b43aa75a05..4d36f88e089bcf61d137541886c608a8286416c9 100644 (file)
@@ -11,7 +11,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 void main() {
-    const uint i = gl_GlobalInvocationID.x;
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
 
     if (i >= p.KX) {
         return;
index 1b8419c7cf2a38446c9fb2c8de19a90dcf545e8e..0bd51ecab58700b229358c7f046be9e7843fb89f 100644 (file)
@@ -28,7 +28,7 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
 
 void main() {
     const uint tid = gl_LocalInvocationID.x;
-    const uint rowx = gl_WorkGroupID.x;
+    const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint rowy = rowx % p.KY;
 
     float slope = 1.0f;
index 8dd19333d4e32e33e0e37a7819f01f17b6e0d6f7..1fa118c996e04e12f07a67f927f5791464028685 100644 (file)
@@ -4,10 +4,12 @@
 #include "generic_unary_head.comp"
 
 void main() {
-    if (gl_GlobalInvocationID.x >= p.ne) {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
         return;
     }
 
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
-    data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
 }
index ce2f1e2f3b3e342e2a54e54287a5b0a229c4b11b..961e5ffa1f56f42b6b35a855ecc582d47ac7d298 100644 (file)
@@ -14,7 +14,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.x;
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint col = gl_LocalInvocationID.x;
 
     tmp[col] = FLOAT_TYPE(0.0f);
diff --git a/src/vulkan-shaders/tanh.comp b/src/vulkan-shaders/tanh.comp
new file mode 100644 (file)
index 0000000..74630dc
--- /dev/null
@@ -0,0 +1,21 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    data_d[i] = D_TYPE(tanh(data_a[i]));
+}
diff --git a/src/vulkan-shaders/timestep_embedding.comp b/src/vulkan-shaders/timestep_embedding.comp
new file mode 100644 (file)
index 0000000..79e065a
--- /dev/null
@@ -0,0 +1,41 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+    uint nb1;
+    uint dim;
+    uint max_period;
+} p;
+
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_WorkGroupID.y;
+    const uint j = gl_GlobalInvocationID.x;
+    const uint d_offset = i * p.nb1;
+
+    if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
+        data_d[d_offset + p.dim] = 0.f;
+    }
+
+    const uint half_dim = p.dim / 2;
+    if (j >= half_dim) {
+        return;
+    }
+
+    const float timestep = float(data_a[i]);
+    const float freq = float(exp(-log(p.max_period) * j / half_dim));
+    const float arg = timestep * freq;
+    data_d[d_offset + j] = D_TYPE(cos(arg));
+    data_d[d_offset + j + half_dim] = D_TYPE(sin(arg));
+}
index d24c172cad3fca7d9768b9496dd064aac1e6d3cd..21dce72fc7dfb3e7d1a2bd80de30dbea2a8ff681 100644 (file)
@@ -6,7 +6,7 @@
 #define QUANT_K 1
 #define QUANT_R 1
 
-#ifndef LOAD_VEC_A
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
 #define A_TYPE float
 #elif LOAD_VEC_A == 4
 #define A_TYPE vec4
@@ -19,7 +19,7 @@
 #define QUANT_K 1
 #define QUANT_R 1
 
-#ifndef LOAD_VEC_A
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
 #define A_TYPE float16_t
 #elif LOAD_VEC_A == 4
 #define A_TYPE f16vec4
diff --git a/src/vulkan-shaders/upscale.comp b/src/vulkan-shaders/upscale.comp
new file mode 100644 (file)
index 0000000..511a086
--- /dev/null
@@ -0,0 +1,36 @@
+#version 450
+
+layout (push_constant) uniform parameter
+{
+    uint ne; uint d_offset;
+    uint nb00; uint nb01; uint nb02; uint nb03;
+    uint ne10; uint ne11; uint ne12; uint ne13;
+    float sf0; float sf1; float sf2; float sf3;
+} p;
+
+#include "types.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i10 = idx % p.ne10;
+    const uint i11 = (idx / p.ne10) % p.ne11;
+    const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
+    const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
+
+    const uint i00 = uint(i10 / p.sf0);
+    const uint i01 = uint(i11 / p.sf1);
+    const uint i02 = uint(i12 / p.sf2);
+    const uint i03 = uint(i13 / p.sf3);
+
+    data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
+}
index c5be3754bfed383dd449e46dc7af73fc616b946e..258a1933f6b229922119e13ea12b06840cf558ea 100644 (file)
@@ -269,9 +269,12 @@ void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmu
 
     for (const auto& tname : type_names) {
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+        // For unaligned, load one at a time for f32/f16, or two at a time for quants
+        std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
+        // For aligned matmul loads
         std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
         tasks.push_back(std::async(std::launch::async, [=] {
-            string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
+            string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
         }));
         tasks.push_back(std::async(std::launch::async, [=] {
             string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
@@ -340,6 +343,9 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
     tasks.push_back(std::async(std::launch::async, [=] {
         string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     }));
+    tasks.push_back(std::async(std::launch::async, [=] {
+        string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    }));
     tasks.push_back(std::async(std::launch::async, [=] {
         string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     }));
@@ -357,6 +363,9 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+    }));
 
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -382,15 +391,42 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
         string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
+
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    }));
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
+    }));
+
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
+
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     }));
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     }));
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     }));
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    }));
 
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -424,6 +460,17 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
     tasks.push_back(std::async(std::launch::async, [=] {
         string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     }));
+
+    tasks.push_back(std::async(std::launch::async, [=] {
+        string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    }));
+    tasks.push_back(std::async(std::launch::async, [=] {
+        string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
+    }));
+
+    tasks.push_back(std::async(std::launch::async, [=] {
+        string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    }));
 }
 
 void write_output_files() {