]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan : refactor buffer handling in vk_op_f32 (#16840)
authorAcly <redacted>
Fri, 7 Nov 2025 20:08:50 +0000 (21:08 +0100)
committerGitHub <redacted>
Fri, 7 Nov 2025 20:08:50 +0000 (21:08 +0100)
* vulkan : refactor/simplify buffer handling in vk_op_* functions

* Combine UMA handling into ggml_vk_tensor_subbuffer

ggml/src/ggml-vulkan/ggml-vulkan.cpp

index ab94bc3d78f682a04cded2e882ee3991741cb3dd..a0a05f2e5b2d0d3e37ffb34e2302f7e8e1a11029 100644 (file)
@@ -5387,7 +5387,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
     device->pinned_memory.erase(device->pinned_memory.begin() + index);
 }
 
-static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
+static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
     std::lock_guard<std::recursive_mutex> guard(device->mutex);
     buf = nullptr;
     buf_offset = 0;
@@ -5402,6 +5402,32 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf
     }
 }
 
+static vk_subbuffer ggml_vk_tensor_subbuffer(
+    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
+
+    vk_buffer buffer = nullptr;
+    size_t offset = 0;
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
+    }
+    if (!buffer) {
+        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
+        buffer = buf_ctx->dev_buffer;
+        offset = vk_tensor_offset(tensor) + tensor->view_offs;
+    }
+    GGML_ASSERT(buffer != nullptr);
+
+    size_t size = ggml_nbytes(tensor);
+
+    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
+    // The shader must support misaligned offsets when indexing into the buffer
+    GGML_ASSERT(allow_misalign || misalign_bytes == 0);
+    offset &= ~misalign_bytes;
+    size += misalign_bytes;
+
+    return vk_subbuffer{buffer, offset, size};
+}
+
 static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
     vk_submission s;
     s.buffer = ggml_vk_create_cmd_buffer(device, p);
@@ -7953,72 +7979,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
-
-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
-        Q_uma = d_Q != nullptr;
-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        D_uma = d_D != nullptr;
-        if (mask) {
-            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
-            M_uma = d_M != nullptr;
-        }
-        if (sinks) {
-            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
-            S_uma = d_S != nullptr;
-        }
-    }
-
-
-    ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-
-    if (!Q_uma) {
-        d_Q = q_buf_ctx->dev_buffer;
-        q_buf_offset = vk_tensor_offset(q) + q->view_offs;
-    }
-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_buf_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_buf_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!D_uma) {
-        d_D = d_buf_ctx->dev_buffer;
-        d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    }
-
-    if (!M_uma) {
-        d_M = d_Q;
-        m_buf_offset = q_buf_offset;
-        if (mask) {
-            ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
-            d_M = m_buf_ctx->dev_buffer;
-            m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
-        }
-    }
-
-    if (!S_uma) {
-        d_S = d_Q;
-        s_buf_offset = q_buf_offset;
-        if (sinks) {
-            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
-            d_S = s_buf_ctx->dev_buffer;
-            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
-        }
-    }
+    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
+    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
+    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
+    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
 
     uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
@@ -8040,15 +8006,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_sync_buffers(ctx, subctx);
         }
 
+        vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {
-                                        ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
-                                        ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
-                                    },
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
                                     // We only use split_k when group query attention is enabled, which means
                                     // there's no more than one tile of rows (i.e. workgroups_x would have been
                                     // one). We reuse workgroups_x to mean the number of splits, so we need to
@@ -8058,23 +8018,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_sync_buffers(ctx, subctx);
         const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
-                                    {
-                                        ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
-                                        ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
-                                    },
+                                    {split_k_buf, sinks_buf, dst_buf},
                                     pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
         ctx->prealloc_split_k_need_sync = true;
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {
-                                        ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
-                                        ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
-                                    },
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
                                     pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
@@ -8757,35 +8706,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint64_t ne01 = src0->ne[1];
     const uint64_t ne02 = src0->ne[2];
     const uint64_t ne03 = src0->ne[3];
-    const uint64_t ne0 = ne00 * ne01;
 
     const bool use_src1 = src1 != nullptr;
     const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
     const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
     const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
     const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
-    const uint64_t ne1 = ne10 * ne11;
-    // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0;
 
     const bool use_src2 = src2 != nullptr;
-    const uint64_t ne20 = use_src2 ? src2->ne[0] : 0;
-    const uint64_t ne21 = use_src2 ? src2->ne[1] : 0;
-    const uint64_t ne22 = use_src2 ? src2->ne[2] : 0;
-    const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
-    const uint64_t ne2 = ne20 * ne21;
-
     const bool use_src3 = src3 != nullptr;
-    const uint64_t ne30 = use_src3 ? src3->ne[0] : 0;
-    const uint64_t ne31 = use_src3 ? src3->ne[1] : 0;
-    const uint64_t ne32 = use_src3 ? src3->ne[2] : 0;
-    const uint64_t ne33 = use_src3 ? src3->ne[3] : 0;
-    const uint64_t ne3 = ne30 * ne31;
-
-    const uint64_t ned0 = dst->ne[0];
-    const uint64_t ned1 = dst->ne[1];
-    const uint64_t ned2 = dst->ne[2];
-    const uint64_t ned3 = dst->ne[3];
-    const uint64_t ned = ned0 * ned1;
 
     init_pushconst_fastdiv(pc);
 
@@ -8804,74 +8733,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
-    ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
-    ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr;
-
-    vk_buffer d_X = nullptr;
-    size_t x_buf_offset = 0;
-    vk_buffer d_Y = nullptr;
-    size_t y_buf_offset = 0;
-    vk_buffer d_Z = nullptr;
-    size_t z_buf_offset = 0;
-    vk_buffer d_W = nullptr;
-    size_t w_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-    bool src2_uma = false;
-    bool src3_uma = false;
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
+    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
 
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
-        src0_uma = d_X != nullptr;
-        if (use_src1) {
-            ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
-            src1_uma = d_Y != nullptr;
-        }
-        if (use_src2) {
-            ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
-            src2_uma = d_Z != nullptr;
-        }
-        if (use_src3) {
-            ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset);
-            src3_uma = d_W != nullptr;
-        }
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-
-    GGML_ASSERT(d_D != nullptr);
-    uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    if(!src0_uma) {
-        d_X = src0_buf_ctx->dev_buffer;
-        x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_X != nullptr);
-    }
-    if (use_src1 && !src1_uma) {
-        d_Y = src1_buf_ctx->dev_buffer;
-        y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Y != nullptr);
-    }
-    if (use_src2 && !src2_uma) {
-        d_Z = src2_buf_ctx->dev_buffer;
-        z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
-        GGML_ASSERT(d_Z != nullptr);
-    }
-    if (use_src3 && !src3_uma) {
-        d_W = src3_buf_ctx->dev_buffer;
-        w_buf_offset = vk_tensor_offset(src3) + src3->view_offs;
-        GGML_ASSERT(d_W != nullptr);
-    }
-    // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
+    // Compute misalignment offset for descriptors and store it in in push constants.
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
-    x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
 
     std::array<uint32_t, 3> elements;
 
@@ -8955,9 +8824,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             const uint32_t KH = ne01;
             const uint32_t KW = ne00;
 
-            const uint32_t OD = ned3 / N;
-            const uint32_t OH = ned2;
-            const uint32_t OW = ned1;
+            const uint32_t OD = dst->ne[3] / N;
+            const uint32_t OH = dst->ne[2];
+            const uint32_t OW = dst->ne[1];
 
             const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
             const uint32_t N_OD_OH = N*OD*OH;
@@ -9072,112 +8941,50 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     }
 
-    uint64_t x_sz, y_sz, z_sz, w_sz, d_sz;
-
-    if (op_supports_incontiguous) {
-        x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
-        y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
-        z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
-        w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0;
-        d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
-
-        if (x_buf_offset + x_sz >= d_X->size) {
-            x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
-        }
-        if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
-            y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
-        }
-        if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
-            z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
-        }
-        if (use_src3 && w_buf_offset + w_sz >= d_W->size) {
-            w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset);
-        }
-        if (d_buf_offset + d_sz >= d_D->size) {
-            d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
-        }
-    } else {
-        x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
-        y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
-        z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
-        w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0;
-        d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
-    }
-
     if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
-        vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
-        size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
+        vk_subbuffer a_buf = src0_buf;
+        if (ctx->do_add_rms_partials) {
+            a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
+        }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
-              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
-              vk_subbuffer{ d_D, d_buf_offset, d_sz },
-              ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
-            }, pc, elements);
+            { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
     } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
     } else if (op == GGML_OP_SOFT_MAX) {
         // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        vk_subbuffer subbuf_z;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
+        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
-        // Empty src2 is possible in rope, but the shader needs a buffer
-        vk_subbuffer subbuf_z, subbuf_w;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-        if (use_src3) {
-            subbuf_w = { d_W, w_buf_offset, w_sz };
-        } else {
-            subbuf_w = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements);
+        // Empty src2 and src3 is possible in rope, but the shader needs a buffer
+        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
+        vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
     } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
         if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
             // buffer device address path doesn't use dst buffer
-            d_sz = 1;
+            dst_buf.size = 1;
         }
         // im2col uses only src1 and dst buffers
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
     } else if (op == GGML_OP_COUNT_EQUAL) {
         // count_equal assumes that destination buffer is initialized with zeroes
-        ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
+        ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
         ggml_vk_sync_buffers(ctx, subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
     } else if (op == GGML_OP_OPT_STEP_SGD) {
         // OPT_STEP_SGD works on src0, it does not need dst
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
     } else if (use_src3) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
     } else if (use_src2) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
     } else if (use_src1) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
     } else {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
     }
 }
 
@@ -9413,39 +9220,10 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
 
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
-    for (int i = 0; i < num_srcs; i++) {
-        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
-    }
-
-    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
-    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
-    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
-
-    if (ctx->device->uma) {
-        for (int i = 0; i < num_srcs; i++) {
-            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
-            srcs_uma[i] = d_srcs[i] != nullptr;
-        }
-
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
-        dst_uma = d_D != nullptr;
-    }
-
-    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[7] = {};
     for (int i = 0; i < num_srcs; i++) {
-        src_sizes[i] = ggml_nbytes(dst->src[i]);
-        if (!srcs_uma[i]) {
-            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
-            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
-        }
-    }
-
-    const uint64_t dst_size = ggml_nbytes(dst);
-    if (!dst_uma) {
-        d_D = dst_buf_ctx->dev_buffer;
-        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
     }
 
     std::array<uint32_t, 3> elements = {
@@ -9455,26 +9233,13 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
     };
 
     if (version == 6) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
-            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
-            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
-            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
-            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
-            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
-            vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+            pc, elements);
     } else if (version == 7) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
-            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
-            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
-            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
-            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
-            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
-            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
-            vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
+            pc, elements);
     } else {
         // shouldn't happen
         GGML_ASSERT(false);
@@ -9554,40 +9319,10 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head, head_dim, n_group, n_tok
     };
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
-    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
-    }
-
-    vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
-    size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
-    bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
-
-    if (ctx->device->uma) {
-        for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
-            srcs_uma[i] = d_srcs[i] != nullptr;
-        }
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
-        dst_uma = d_D != nullptr;
-    }
-
-    if (!dst_uma) {
-        d_D = dst_buf_ctx->dev_buffer;
-        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
-    }
-    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-        if (!srcs_uma[i]) {
-            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
-            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
-        }
-    }
-
-    size_t dst_size = ggml_nbytes(dst);
-    size_t src_sizes[GGML_MAX_SRC];
-    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-        src_sizes[i] = ggml_nbytes(dst->src[i]);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[7] = {};
+    for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
     }
 
     std::array<uint32_t, 3> elements;
@@ -9597,16 +9332,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t num_workgroups_y = n_seq;
     elements = { num_workgroups_x, num_workgroups_y, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
-        vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
-        vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
-        vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
-        vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
-        vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
-        vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
+        pc, elements);
 }
 
 static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -9653,66 +9381,17 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
 
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
-    ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
-    ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
-    ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
-    ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
-    ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
-
-    vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
-    size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
-    bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
-        ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
-        ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
-        ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
-        ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
-
-        X_uma = d_X != nullptr;
-        G_uma = d_G != nullptr;
-        GM_uma = d_GM != nullptr;
-        GV_uma = d_GV != nullptr;
-        P_uma = d_P != nullptr;
-    }
-
-    if (!X_uma) {
-        d_X = x_buf_ctx->dev_buffer;
-        x_offset = vk_tensor_offset(x) + x->view_offs;
-    }
-    if (!G_uma) {
-        d_G = g_buf_ctx->dev_buffer;
-        g_offset = vk_tensor_offset(g) + g->view_offs;
-    }
-    if (!GM_uma) {
-        d_GM = gm_buf_ctx->dev_buffer;
-        gm_offset = vk_tensor_offset(gm) + gm->view_offs;
-    }
-    if (!GV_uma) {
-        d_GV = gv_buf_ctx->dev_buffer;
-        gv_offset = vk_tensor_offset(gv) + gv->view_offs;
-    }
-    if (!P_uma) {
-        d_P = p_buf_ctx->dev_buffer;
-        p_offset = vk_tensor_offset(p) + p->view_offs;
-    }
-
-    const uint64_t x_size = ggml_nbytes(x);
-    const uint64_t g_size = ggml_nbytes(g);
-    const uint64_t gm_size = ggml_nbytes(gm);
-    const uint64_t gv_size = ggml_nbytes(gv);
-    const uint64_t p_size = ggml_nbytes(p);
+    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
+    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);
+    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);
+    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);
+    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);
 
     std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_X, x_offset, x_size },
-        vk_subbuffer{ d_G, g_offset, g_size },
-        vk_subbuffer{ d_GM, gm_offset, gm_size },
-        vk_subbuffer{ d_GV, gv_offset, gv_size },
-        vk_subbuffer{ d_P, p_offset, p_size },
-    }, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {x_buf, g_buf, gm_buf, gv_buf, p_buf},
+        pc, elements);
 }
 
 static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -10044,45 +9723,9 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
 
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
-    ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
-    ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
-    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
-
-    vk_buffer d_logits = nullptr;
-    size_t logits_buf_offset = 0;
-    vk_buffer d_weights = nullptr;
-    size_t weights_buf_offset = 0;
-    vk_buffer d_ids = nullptr;
-    size_t ids_buf_offset = 0;
-
-    bool logits_uma = false;
-    bool weights_uma = false;
-    bool ids_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
-        ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
-        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
-        logits_uma = d_logits != nullptr;
-        weights_uma = d_weights != nullptr;
-        ids_uma = d_ids != nullptr;
-    }
-
-    if (!logits_uma) {
-        d_logits = logits_buf_ctx->dev_buffer;
-        logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
-        GGML_ASSERT(d_logits != nullptr);
-    }
-    if (!weights_uma) {
-        d_weights = weights_buf_ctx->dev_buffer;
-        weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
-        GGML_ASSERT(d_weights != nullptr);
-    }
-    if (!ids_uma) {
-        d_ids = ids_buf_ctx->dev_buffer;
-        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
-        GGML_ASSERT(d_ids != nullptr);
-    }
+    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
+    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
+    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
 
     vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
@@ -10098,12 +9741,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t rows_per_block = 4;
     std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-        {
-            ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
-            ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
-            ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
-        }, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {