]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: implement several ops relevant for ggml_opt (llama/11769)
authorRémy O <redacted>
Mon, 17 Feb 2025 06:55:57 +0000 (07:55 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
* vulkan: support memset_tensor

* vulkan: support GGML_OP_SUM

* vulkan: implement GGML_OP_ARGMAX

* vulkan: implement GGML_OP_SUB

* vulkan: implement GGML_OP_COUNT_EQUAL

* vulkan: implement GGML_OP_OPT_STEP_ADAMW

* vulkan: fix check_results RWKV_WKV6 crash and memory leaks

* vulkan: implement GGML_OP_REPEAT_BACK

* tests: remove invalid test-backend-ops REPEAT_BACK tests

* vulkan: fix COUNT_EQUAL memset using a fillBuffer command

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 88f31c1ef8b2f7dc4ff950715e93782c108b6dd6..131ee1ea044dd9df6e93d7c3bae5ae5368d8877d 100644 (file)
@@ -222,6 +222,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -232,7 +233,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
-    vk_pipeline pipeline_repeat_f32;
+    vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
     vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -255,10 +256,13 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_argmax_f32;
+    vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -2147,6 +2151,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -2169,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -2203,8 +2210,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@@ -2218,6 +2229,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -3783,6 +3796,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
     }
 }
 
+static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
+    VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
+
+    ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
+}
+
 static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
 
@@ -5189,6 +5208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SUB:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        }
+        return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@@ -5250,6 +5274,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_repeat_f32;
         }
         return nullptr;
+    case GGML_OP_REPEAT_BACK:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_repeat_back_f32;
+        }
+        return nullptr;
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
@@ -5358,11 +5387,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_argsort_f32;
         }
         return nullptr;
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_sum_rows_f32;
         }
         return nullptr;
+    case GGML_OP_ARGMAX:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
+            return ctx->device->pipeline_argmax_f32;
+        }
+        return nullptr;
+    case GGML_OP_COUNT_EQUAL:
+        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
+            return ctx->device->pipeline_count_equal_i32;
+        }
+        return nullptr;
     case GGML_OP_IM2COL:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_im2col_f32;
@@ -5386,6 +5426,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv6_f32;
         }
         return nullptr;
+    case GGML_OP_OPT_STEP_ADAMW:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_opt_step_adamw_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -5403,6 +5448,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CPY:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -5413,6 +5459,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_REPEAT:
+    case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
         return true;
     default:
@@ -5627,6 +5674,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_RMS_NORM:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_ARGMAX:
         {
             const uint32_t nr = ggml_nrows(src0);
             if (nr > 262144) {
@@ -5637,6 +5685,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_SUM:
+        // We use GGML_OP_SUM_ROWS with 1 row.
+        elements = { 1, 1, 1 };
+        break;
     case GGML_OP_GROUP_NORM:
         {
             const uint32_t num_groups = dst->op_params[0];
@@ -5683,6 +5735,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { N * OC * OH * OW, 1, 1};
         } break;
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
@@ -5692,6 +5745,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_REPEAT:
+    case GGML_OP_REPEAT_BACK:
     case GGML_OP_CPY:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
@@ -5752,6 +5806,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         // im2col uses only src1 and dst buffers
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+    } else if (op == GGML_OP_COUNT_EQUAL) {
+        ggml_vk_sync_buffers(subctx);
+        // count_equal assumes that destination buffer is initialized with zeroes
+        ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -5814,6 +5874,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -5972,6 +6047,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * x = dst->src[0];
+    const ggml_tensor * g = dst->src[1];
+    const ggml_tensor * gm = dst->src[2];
+    const ggml_tensor * gv = dst->src[3];
+    const ggml_tensor * p = dst->src[4];
+
+    GGML_ASSERT(x->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(gm->type == GGML_TYPE_F32);
+    GGML_ASSERT(gv->type == GGML_TYPE_F32);
+    GGML_ASSERT(p->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->buffer != nullptr);
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(gm));
+    GGML_ASSERT(ggml_is_contiguous(gv));
+    GGML_ASSERT(ggml_is_contiguous(p));
+    GGML_ASSERT(ggml_are_same_shape(x, g));
+    GGML_ASSERT(ggml_are_same_shape(x, gm));
+    GGML_ASSERT(ggml_are_same_shape(x, gv));
+    GGML_ASSERT(ggml_nelements(p) == 7);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
+    ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
+    ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
+    ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
+    ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
+    size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
+    bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
+        ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
+        ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
+        ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
+        ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
+
+        X_uma = d_X != nullptr;
+        G_uma = d_G != nullptr;
+        GM_uma = d_GM != nullptr;
+        GV_uma = d_GV != nullptr;
+        P_uma = d_P != nullptr;
+    }
+
+    if (!X_uma) {
+        d_X = x_buf_ctx->dev_buffer;
+        x_offset = vk_tensor_offset(x) + x->view_offs;
+    }
+    if (!G_uma) {
+        d_G = g_buf_ctx->dev_buffer;
+        g_offset = vk_tensor_offset(g) + g->view_offs;
+    }
+    if (!GM_uma) {
+        d_GM = gm_buf_ctx->dev_buffer;
+        gm_offset = vk_tensor_offset(gm) + gm->view_offs;
+    }
+    if (!GV_uma) {
+        d_GV = gv_buf_ctx->dev_buffer;
+        gv_offset = vk_tensor_offset(gv) + gv->view_offs;
+    }
+    if (!P_uma) {
+        d_P = p_buf_ctx->dev_buffer;
+        p_offset = vk_tensor_offset(p) + p->view_offs;
+    }
+
+    const uint64_t x_size = ggml_nbytes(x);
+    const uint64_t g_size = ggml_nbytes(g);
+    const uint64_t gm_size = ggml_nbytes(gm);
+    const uint64_t gv_size = ggml_nbytes(gv);
+    const uint64_t p_size = ggml_nbytes(p);
+
+    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_X, x_offset, x_size },
+        vk_subbuffer{ d_G, g_offset, g_size },
+        vk_subbuffer{ d_GM, gm_offset, gm_size },
+        vk_subbuffer{ d_GV, gv_offset, gv_size },
+        vk_subbuffer{ d_P, p_offset, p_size },
+    }, sizeof(vk_op_push_constants), &pc, elements);
+}
+
+static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32_opt_step_adamw(
+        ctx, subctx, dst,
+        { (uint32_t)n, 0, 0.0f, 0.0f },
+        dryrun
+    );
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -6105,6 +6285,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
     }, dryrun);
 }
 
+static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -6227,10 +6421,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
+}
+
+static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const int32_t s0 = dst->op_params[0];
     const int32_t s1 = dst->op_params[1];
@@ -7095,9 +7301,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         }
         break;
     case GGML_OP_REPEAT:
+    case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
     case GGML_OP_ACC:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7120,13 +7328,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
+    case GGML_OP_OPT_STEP_ADAMW:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7147,9 +7359,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     } else {
         switch (node->op) {
         case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT_BACK:
         case GGML_OP_ACC:
         case GGML_OP_GET_ROWS:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
@@ -7171,7 +7385,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
         case GGML_OP_ARGSORT:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
@@ -7192,6 +7409,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_REPEAT:
         ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_REPEAT_BACK:
+        ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_ACC:
         ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7204,6 +7425,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ADD:
         ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SUB:
+        ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_MUL:
         ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7291,10 +7516,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ARGSORT:
         ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SUM:
+        ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_ARGMAX:
+        ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
+
+        break;
+    case GGML_OP_COUNT_EQUAL:
+        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_IM2COL:
         ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7329,6 +7566,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RWKV_WKV6:
         ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
 
+        break;
+
+    case GGML_OP_OPT_STEP_ADAMW:
+        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -7380,6 +7622,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_ADD:
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7405,13 +7648,18 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_TRANSPOSE:
     case GGML_OP_NONE:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
+    case GGML_OP_REPEAT_BACK:
+    case GGML_OP_OPT_STEP_ADAMW:
         buf = tensor->buffer;
 
         break;
@@ -7603,6 +7851,15 @@ static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggm
     }
 }
 
+static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
+    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+    vk_buffer buf = buf_ctx->dev_buffer;
+
+    uint32_t val32 = (uint32_t)value * 0x01010101;
+    ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
+}
+
 static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@@ -7647,7 +7904,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_vk_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
+    /* .memset_tensor   = */ ggml_backend_vk_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,
@@ -8300,6 +8557,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             } break;
         case GGML_OP_REPEAT:
             return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
+        case GGML_OP_REPEAT_BACK:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ROPE:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -8313,6 +8572,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
@@ -8326,12 +8586,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ARGSORT:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
         default:
             return false;
@@ -8604,8 +8868,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
-    ggml_tensor * src2 = tensor->src[2];
-    ggml_tensor * src3 = tensor->src[3];
 
     struct ggml_init_params iparams = {
         /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul,
@@ -8615,238 +8877,113 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     struct ggml_context * ggml_ctx = ggml_init(iparams);
 
-    struct ggml_tensor * src0_clone = nullptr;
-    struct ggml_tensor * src1_clone = nullptr;
-    struct ggml_tensor * src2_clone = nullptr;
-    struct ggml_tensor * src3_clone = nullptr;
-    struct ggml_tensor * tensor_clone = nullptr;
-
-    size_t src0_size;
-    size_t src1_size;
-    size_t src2_size;
-    size_t src3_size;
-
-    void * src0_buffer = nullptr;
-    void * src1_buffer = nullptr;
-    void * src2_buffer = nullptr;
-    void * src3_buffer = nullptr;
-
-    if (src0 != nullptr) {
-        src0_clone = ggml_dup_tensor(ggml_ctx, src0);
-
-        src0_size = ggml_nbytes(src0);
+    std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+    std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
+    std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+    const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
 
-        src0_buffer = malloc(src0_size);
-        src0_clone->data = src0_buffer;
-        if (ggml_backend_buffer_is_host(src0->buffer)) {
-            memcpy(src0_clone->data, src0->data, src0_size);
-            memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
-            if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
-                for (int i3 = 0; i3 < src0->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src0->ne[2]; i2++) {
-                        const int idx = i3*src0->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
-                    }
-                }
-
-                src0_clone->nb[0] = src0->nb[0];
-                src0_clone->nb[1] = src0->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src0_size >= buffer_gpu->size) {
-                    src0_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
-                memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src0, "src0");
-        }
-    }
-    if (src1 != nullptr) {
-        src1_clone = ggml_dup_tensor(ggml_ctx, src1);
-
-        src1_size = ggml_nbytes(src1);
-
-        src1_buffer = malloc(src1_size);
-        src1_clone->data = src1_buffer;
-        if (ggml_backend_buffer_is_host(src1->buffer)) {
-            memcpy(src1_clone->data, src1->data, src1_size);
-            memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
-            if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
-                for (int i3 = 0; i3 < src1->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src1->ne[2]; i2++) {
-                        const int idx = i3*src1->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
-                    }
-                }
-
-                src1_clone->nb[0] = src1->nb[0];
-                src1_clone->nb[1] = src1->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src1_size >= buffer_gpu->size) {
-                    src1_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
-                memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src1, "src1");
-        }
-    }
-    if (src2 != nullptr) {
-        src2_clone = ggml_dup_tensor(ggml_ctx, src2);
-
-        src2_size = ggml_nbytes(src2);
-
-        src2_buffer = malloc(src2_size);
-        src2_clone->data = src2_buffer;
-        if (ggml_backend_buffer_is_host(src2->buffer)) {
-            memcpy(src2_clone->data, src2->data, src2_size);
-            memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
-            if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
-                for (int i3 = 0; i3 < src2->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src2->ne[2]; i2++) {
-                        const int idx = i3*src2->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
-                    }
-                }
-
-                src2_clone->nb[0] = src2->nb[0];
-                src2_clone->nb[1] = src2->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src2_size >= buffer_gpu->size) {
-                    src2_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
-                memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
+    struct ggml_tensor * tensor_clone = nullptr;
 
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src2, "src2");
+    for (int i = 0; i < 6; i++) {
+        ggml_tensor * srci = tensor->src[i];
+        if (srci == nullptr) {
+            continue;
         }
-    }
-    if (src3 != nullptr) {
-        src3_clone = ggml_dup_tensor(ggml_ctx, src3);
+        ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
+        size_t srci_size = ggml_nbytes(srci);
 
-        src3_size = ggml_nbytes(src3);
+        src_clone[i] = srci_clone;
+        src_size[i] = ggml_nbytes(srci);
+        src_buffer[i] = malloc(srci_size);
 
-        src3_buffer = malloc(src3_size);
-        src3_clone->data = src3_buffer;
-        if (ggml_backend_buffer_is_host(src3->buffer)) {
-            memcpy(src3_clone->data, src3->data, src3_size);
-            memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
+        srci_clone->data = src_buffer[i];
+        if (ggml_backend_buffer_is_host(srci->buffer)) {
+            memcpy(srci_clone->data, srci->data, srci_size);
+            memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
+        } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
+            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
             vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
-            if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
-                for (int i3 = 0; i3 < src3->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src3->ne[2]; i2++) {
-                        const int idx = i3*src3->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
+            uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
+            if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
+                for (int i3 = 0; i3 < srci->ne[3]; i3++) {
+                    for (int i2 = 0; i2 < srci->ne[2]; i2++) {
+                        const int idx = i3*srci->ne[2] + i2;
+                        ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
                     }
                 }
 
-                src3_clone->nb[0] = src3->nb[0];
-                src3_clone->nb[1] = src3->nb[1];
+                srci_clone->nb[0] = srci->nb[0];
+                srci_clone->nb[1] = srci->nb[1];
                 for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
+                    srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
                 }
             } else {
-                if (offset + src3_size >= buffer_gpu->size) {
-                    src3_size = buffer_gpu->size - offset;
+                if (offset + srci_size >= buffer_gpu->size) {
+                    srci_size = buffer_gpu->size - offset;
                 }
-                ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
-                memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
+                ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
+                memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
             }
         } else {
             GGML_ABORT("fatal error");
         }
 
         if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src3, "src3");
+            ggml_vk_print_tensor(srci, srci_name[i]);
         }
     }
 
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float *params = (const float *)tensor->op_params;
-        tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
+        tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
     } else if (tensor->op == GGML_OP_MUL_MAT) {
-        tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
-        tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+        tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_DIV) {
-        tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
-        tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
+        tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_SCALE) {
-        tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
+        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_SQR) {
-        tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
+        tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
-        tensor_clone = ggml_sin(ggml_ctx, src0_clone);
+        tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COS) {
-        tensor_clone = ggml_cos(ggml_ctx, src0_clone);
+        tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
-        tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_PAD) {
-        tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
+        tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
     } else if (tensor->op == GGML_OP_REPEAT) {
-        tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
+        tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
+    } else if (tensor->op == GGML_OP_REPEAT_BACK) {
+        tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
     } else if (tensor->op == GGML_OP_ADD) {
-        tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ACC) {
-        tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
+        tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
     } else if (tensor->op == GGML_OP_NORM) {
-        tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+        tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
-        tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
+        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
-        tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+        tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
-            tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
         } else {
-            tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
+            tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
-        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
+        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_ROPE) {
         const int n_dims      = ((int32_t *) tensor->op_params)[1];
         const int mode        = ((int32_t *) tensor->op_params)[2];
@@ -8860,26 +8997,26 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float beta_slow       = ((float *) tensor->op_params)[10];
         if (mode & GGML_ROPE_TYPE_MROPE) {
             int32_t *sections = ((int32_t *) tensor->op_params) + 11;
-            tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
         } else {
-            tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
-            tensor_clone = ggml_silu(ggml_ctx, src0_clone);
+            tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_GELU:
-            tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
+            tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_GELU_QUICK:
-            tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
+            tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_RELU:
-            tensor_clone = ggml_relu(ggml_ctx, src0_clone);
+            tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_TANH:
-            tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
+            tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
             break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8887,28 +9024,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         }
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
-            tensor_clone = ggml_dup(ggml_ctx, src0_clone);
+            tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
             tensor_clone->type = tensor->type;
         } else {
-            tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
+            tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
     } else if (tensor->op == GGML_OP_CONT) {
-        tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {
-        tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_VIEW) {
-        tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
+        tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_PERMUTE) {
         int32_t * params = (int32_t *)tensor->op_params;
-        tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
+        tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
     } else if (tensor->op == GGML_OP_TRANSPOSE) {
-        tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
+        tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_GET_ROWS) {
-        tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ARGSORT) {
-        tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
+        tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_SUM) {
+        tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
-        tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
+        tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_ARGMAX) {
+        tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
+        tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_IM2COL) {
         const int32_t s0 = tensor->op_params[0];
         const int32_t s1 = tensor->op_params[1];
@@ -8918,11 +9061,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t d1 = tensor->op_params[5];
 
         const bool is_2D = tensor->op_params[6] == 1;
-        tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
+        tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
     } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
-        tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+        tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
     } else if (tensor->op == GGML_OP_POOL_2D) {
         enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
         const int32_t k0 = tensor->op_params[1];
@@ -8932,13 +9075,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t p0 = tensor->op_params[5];
         const int32_t p1 = tensor->op_params[6];
 
-        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
+        tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
-        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
+        tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
     } else if (tensor->op == GGML_OP_RWKV_WKV6) {
-        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
-        tensor->src[4], tensor->src[5]);
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
+        src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
+        src_clone[0]->flags = src0->flags;
+        tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
+        src_clone[2], src_clone[3], src_clone[4]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8960,11 +9107,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     memcpy(comp_result, tensor_clone->data, comp_size);
     memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
 
-    if (src0 != nullptr) {
-        free(src0_buffer);
-    }
-    if (src1 != nullptr) {
-        free(src1_buffer);
+    for (int i = 0; i < 6; i++) {
+        if (src_buffer[i] != nullptr) {
+            free(src_buffer[i]);
+        }
     }
 
     ggml_free(ggml_ctx);
@@ -9028,6 +9174,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
                         } else if (tensor->type == GGML_TYPE_I32) {
                             correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
                             result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+                        } else if (tensor->type == GGML_TYPE_I64) {
+                            correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+                            result  = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
                         } else {
                             std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
                         }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
new file mode 100644 (file)
index 0000000..eaf4da3
--- /dev/null
@@ -0,0 +1,51 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
+shared uint tmp[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint col = gl_LocalInvocationID.x;
+
+    if (col >= p.KX) {
+        return;
+    }
+    A_TYPE amax = data_a[row*p.KX + col];
+    tmp[col] = col;
+
+    for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
+        A_TYPE val = data_a[row*p.KX + i];
+        if (val > amax) {
+            amax = val;
+            tmp[col] = i;
+        }
+    }
+    tmpmax[col] = amax;
+
+    barrier();
+    [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
+        if (col < s && col + s < p.KX) {
+            if (tmpmax[col] < tmpmax[col + s]) {
+                tmpmax[col] = tmpmax[col + s];
+                tmp[col] = tmp[col + s];
+            }
+        }
+        barrier();
+    }
+
+    if (col == 0) {
+        data_d[row] = D_TYPE(tmp[0]);
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
new file mode 100644 (file)
index 0000000..d934549
--- /dev/null
@@ -0,0 +1,31 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.comp"
+#include "generic_head.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+const uint CHUNK_SIZE = 512;
+
+void main() {
+    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
+    const uint col = gl_LocalInvocationID.x;
+
+    uint count = 0;
+    [[unroll]]
+    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
+        const uint idx = base + i + col;
+        if (idx >= p.KX) {
+            break;
+        }
+        count += uint(data_a[idx] == data_b[idx]);
+    }
+
+    atomicAdd(data_d[0], D_TYPE(count));
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
new file mode 100644 (file)
index 0000000..e0214fe
--- /dev/null
@@ -0,0 +1,42 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE x[];};
+layout (binding = 1) readonly buffer G {A_TYPE grad[];};
+layout (binding = 2) buffer GM {A_TYPE gradm[];};
+layout (binding = 3) buffer GV {A_TYPE gradv[];};
+layout (binding = 4) readonly buffer P {float params[7];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha  = params[0];
+    const float beta1  = params[1];
+    const float beta2  = params[2];
+    const float eps    = params[3];
+    const float wd     = params[4];
+    const float beta1h = params[5];
+    const float beta2h = params[6];
+
+    const float gi = grad[i];
+    const float gmi = gradm[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    gradm[i] = gmi;
+    gradv[i] = gvi;
+
+    const float mh =      gmi*beta1h;
+    const float vh = sqrt(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
new file mode 100644 (file)
index 0000000..d862799
--- /dev/null
@@ -0,0 +1,37 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    // Destination multi-index (inlined dst_idx)
+    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+    const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
+
+    // Accumulate from sources
+    A_TYPE acc = A_TYPE(0);
+    for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
+        for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
+            for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
+                for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
+                    acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
+                }
+            }
+        }
+    }
+
+    data_d[get_doffset() + d_idx] = D_TYPE(acc);
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
new file mode 100644 (file)
index 0000000..72353cc
--- /dev/null
@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
+}
index ba9163af27ad965f1e7499380a5bc35a40d143ec..3128c3d507a61f07276a6593e48d3d863ca6d898 100644 (file)
@@ -443,6 +443,8 @@ void process_shaders() {
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 
+    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -452,6 +454,7 @@ void process_shaders() {
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
@@ -501,7 +504,9 @@ void process_shaders() {
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
+    string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 
     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
@@ -513,6 +518,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }