]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader (#13191)
authorJeff Bolz <redacted>
Thu, 1 May 2025 18:19:31 +0000 (13:19 -0500)
committerGitHub <redacted>
Thu, 1 May 2025 18:19:31 +0000 (20:19 +0200)
* vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp

index c0bdb9e17a7b498ad5177fd395d9371f60de21a7..4614c3c1563017cd79d20b507ec0b573253a2ec6 100644 (file)
@@ -2399,7 +2399,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len,              mul_mat_vec_p021_f16_f32_data,              "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
         }
     }
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -4949,6 +4949,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t nb01 = src0->nb[1];
     const uint64_t nb02 = src0->nb[2];
 
+    const uint64_t nb12 = src1->nb[2];
+
     // const uint64_t ne10 = src1->ne[0];
     const uint64_t ne11 = src1->ne[1];
     const uint64_t ne12 = src1->ne[2];
@@ -4974,6 +4976,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
     const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
     const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
+    const uint32_t channel_stride_y = nb12 / sizeof(float);
 
     const uint64_t qx_sz = ggml_nbytes(src0);
     const uint64_t qy_sz = ggml_nbytes(src1);
@@ -5004,7 +5007,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
 
     // compute
-    const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+    const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
         { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
index 48376637fb3e7c370da2af2808fc706be984013e..bc633369f9bb58dc4445cfddb8503706cb26babb 100644 (file)
@@ -21,7 +21,9 @@ layout (push_constant) uniform parameter
     uint nrows_x;
     uint row_stride_x;
     uint channel_stride_x;
+    uint channel_stride_y;
     uint channel_x_divisor;
+    uint ne12;
     uint b_offset;
     uint d_offset;
 } p;
@@ -33,6 +35,7 @@ void main() {
     const uint row_x     = gl_GlobalInvocationID.y;
     const uint channel   = gl_GlobalInvocationID.z;
     const uint channel_x = channel / p.channel_x_divisor;
+    const uint channel_y = channel % p.ne12;
 
     const uint nrows_y   = p.ncols_x;
     const uint nrows_dst = p.nrows_x;
@@ -56,7 +59,7 @@ void main() {
                 const uint row_y = col_x;
 
                 const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-                const uint iy = channel*nrows_y + row_y;
+                const uint iy = channel_y*p.channel_stride_y + row_y;
 
                 const vec4 av4 = vec4(data_a_v4[ix / 4]);
                 const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -72,7 +75,7 @@ void main() {
             const uint row_y = col_x;
 
             const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-            const uint iy = channel*nrows_y + row_y;
+            const uint iy = channel_y*p.channel_stride_y + row_y;
 
             const vec4 av4 = vec4(data_a_v4[ix / 4]);
             const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -89,7 +92,7 @@ void main() {
             const uint row_y = col_x;
 
             const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-            const uint iy = channel*nrows_y + row_y;
+            const uint iy = channel_y*p.channel_stride_y + row_y;
 
             const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);