]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Optimize mul_mat_vec p021 and nc shaders (#12505)
authorJeff Bolz <redacted>
Sat, 22 Mar 2025 08:40:11 +0000 (03:40 -0500)
committerGitHub <redacted>
Sat, 22 Mar 2025 08:40:11 +0000 (09:40 +0100)
* tests: add mul_mat perf/functional tests for p021/nc vulkan shaders

* vulkan: Optimize mul_mat_vec p021 and nc shaders.

These shaders are used in attention calculations, and when the KV cache grows
large they start to dominate the run time. For the nc shader (which is called
with large 'k' dimension), use unrolling and vector loads. For the p021 shader
(which is called with large 'm' and small 'k' dimensions), take advantage of
grouped query attention to reuse loads from the A matrix for the whole group,
and reduce the number of workgroups (too much overhead from tiny dispatches).

Using subgroupAdd in the p021 shader also helps, use that conditionally.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 649504566ab58819dc4093e03ec02c4d822a20fd..37fa8eec599a3ec6f59c5c6eba0e1d63f49bdbba 100644 (file)
@@ -149,6 +149,7 @@ class vk_perf_logger;
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
+static constexpr uint32_t p021_max_gqa_ratio = 8;
 
 enum vk_device_architecture {
     OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
     bool uma;
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
+    bool subgroup_add;
 
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
-    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
+        if (device->subgroup_add && device->subgroup_require_full_support) {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+        } else {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len,              mul_mat_vec_p021_f16_f32_data,              "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+        }
+    }
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -2479,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceDriverProperties driver_props;
         vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
         vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+        vk::PhysicalDeviceVulkan11Properties vk11_props;
         vk::PhysicalDeviceVulkan12Properties vk12_props;
         vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
 
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
         subgroup_props.pNext = &driver_props;
-        driver_props.pNext = &vk12_props;
+        driver_props.pNext = &vk11_props;
+        vk11_props.pNext = &vk12_props;
 
         VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
 
@@ -2549,6 +2559,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
         }
         device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+        device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
+
         const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4635,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
+    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
+        gqa_ratio = 1;
+    }
+
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
         return;
     }
 
@@ -4661,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
 
     // compute
     const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+
+    uint32_t workgroups_z = (uint32_t)ne12;
+    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
+    if (gqa_ratio > 1) {
+        workgroups_z /= gqa_ratio;
+    }
+
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
 static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
index 1cc4996d393a2a33d67cc0be1b549ca72c721876..48376637fb3e7c370da2af2808fc706be984013e 100644 (file)
@@ -12,6 +12,9 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -37,25 +40,66 @@ void main() {
 
     const uint idst = channel*nrows_dst + row_dst;
 
-    tmp[tid] = 0.0f;
+    FLOAT_TYPE temp = 0.0f;
 
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
 
-        if (col_x >= p.ncols_x) {
-            break;
-        }
+    for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
+
+        // Unroll 2x and do vec4 loads if aligned
+        const uint unroll_count = 2;
+        if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
+            [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
+                const uint col_x = col_x0 + 4*tid;
+
+                const uint row_y = col_x;
+
+                const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+                const uint iy = channel*nrows_y + row_y;
+
+                const vec4 av4 = vec4(data_a_v4[ix / 4]);
+                const vec4 bv4 = vec4(data_b_v4[iy / 4]);
+
+                temp += dot(av4, bv4);
+
+                col_x0 += 4*BLOCK_SIZE;
+            }
+        // do vec4 loads if aligned
+        } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+            const uint col_x = col_x0 + 4*tid;
 
-        const uint row_y = col_x;
+            const uint row_y = col_x;
 
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel*nrows_y + row_y;
+            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+            const uint iy = channel*nrows_y + row_y;
 
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
+            const vec4 bv4 = vec4(data_b_v4[iy / 4]);
 
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
+            temp += dot(av4, bv4);
+
+            col_x0 += 4*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+            if (col_x >= p.ncols_x) {
+                break;
+            }
+
+            const uint row_y = col_x;
+
+            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+            const uint iy = channel*nrows_y + row_y;
+
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+            temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
+            col_x0 += BLOCK_SIZE;
+        }
     }
 
+    tmp[tid] = temp;
+
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
index 9b443807d87817c170d64f03ffa4bba406894972..7aa070eebdf72b4849468098de1a3a30238cf218 100644 (file)
@@ -2,16 +2,25 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
 
-#define BLOCK_SIZE 32
 #define FLOAT_TYPE float
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
+layout(constant_id = 0) const int BLOCK_SIZE = 32;
+// gqa_ratio is in the range [1,8]
+layout(constant_id = 1) const uint gqa_ratio = 1;
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -22,52 +31,124 @@ layout (push_constant) uniform parameter
     uint d_offset;
 } p;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
+#if !USE_SUBGROUP_ADD
+shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
+#endif
 
 void main() {
     const uint tid = gl_LocalInvocationID.x;
     const uint row_x = gl_GlobalInvocationID.y;
-    const uint channel = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
+
+    uint channel, channel_x;
+
+    // When gqa_ratio > 1, each invocation does multiple rows.
+    // The row in the A matrix is starting from channel / gqa_ratio and the
+    // rows in the B matrix are [channel, channel+gqa_ratio).
+    // When gpa_ratio is 1, each invocation does one row.
+    if (gqa_ratio > 1) {
+        channel_x = gl_GlobalInvocationID.z;
+        channel = channel_x * gqa_ratio;
+    } else {
+        channel = gl_GlobalInvocationID.z;
+        channel_x = channel / (p.nchannels_y / p.nchannels_x);;
+    }
 
     const uint nrows_y = p.ncols_x;
     const uint nrows_dst = p.nrows_x;
     const uint row_dst = row_x;
 
-    tmp[tid] = FLOAT_TYPE(0.0f);
+    FLOAT_TYPE temp[8];
+    [[unroll]] for (uint i = 0; i < 8; ++i) {
+        temp[i] = FLOAT_TYPE(0.0f);
+    }
+
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
 
     for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
 
-        if (col_x >= p.ncols_x) {
-            break;
-        }
+        // Use vec4 loads if aligned
+        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
 
-        // x is transposed and permuted
-        const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+            uint col_x = col_x0 + 4*tid;
+            const uint row_y = col_x;
 
-        const uint row_y = col_x;
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
 
-        // y is not transposed but permuted
-        const uint iy = channel*nrows_y + row_y;
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
 
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
-    }
+                vec4 bv4 = data_b_v4[iy / 4];
+                temp[c] += dot(av4, bv4);
+            }
+
+            col_x0 += 3*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+
+            if (col_x >= p.ncols_x) {
+                break;
+            }
 
-    // dst is not transposed and not permuted
-    const uint idst = channel*nrows_dst + row_dst;
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
 
+            const uint row_y = col_x;
+
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
+
+                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
+            }
+        }
+    }
+
+#if USE_SUBGROUP_ADD
+    // reduce vec4 at a time
+    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
+    t = subgroupAdd(t);
+    temp[0] = t[0];
+    temp[1] = t[1];
+    temp[2] = t[2];
+    temp[3] = t[3];
+    if (gqa_ratio > 4) {
+        t = vec4(temp[4], temp[5], temp[6], temp[7]);
+        t = subgroupAdd(t);
+        temp[4] = t[0];
+        temp[5] = t[1];
+        temp[6] = t[2];
+        temp[7] = t[3];
+    }
+#else
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        tmp[c][tid] = temp[c];
+    }
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                temp[c] += tmp[c][tid + s];
+                tmp[c][tid] = temp[c];
+            }
         }
         barrier();
     }
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        temp[c] = tmp[c][tid];
+    }
+#endif
 
     if (tid == 0) {
-        dst[idst] = tmp[0];
+        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+            // dst is not transposed and not permuted
+            const uint idst = (channel + c)*nrows_dst + row_dst;
+            dst[idst] = temp[c];
+        }
     }
 }
index 519e610e31dc60350d048034478f68fe5f22b39f..1edb8267f1ebef7bf671dda9f82a63f0cd10aa0e 100644 (file)
@@ -426,8 +426,9 @@ void process_shaders() {
         }
     }
 
-    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32",              "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
 
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
index d48cd2172315597be4e29e8e871ee01a6a8313f3..9d7847d21b27f63a2c080883f5aed18badc0e82a 100644 (file)
@@ -1964,9 +1964,10 @@ struct test_mul_mat : public test_case {
     const std::array<int64_t, 2> bs;  // dims 3 and 4
     const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
     const std::array<int64_t, 4> per; // permutation of dimensions
+    const bool v; // whether a is a non-contiguous view
 
     std::string vars() override {
-        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
+        return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
     }
 
     double max_nmse_err() override {
@@ -1986,8 +1987,9 @@ struct test_mul_mat : public test_case {
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array<int64_t, 2> bs = {10, 10},
             std::array<int64_t, 2> nr = {2, 2},
-            std::array<int64_t, 4> per = {0, 1, 2, 3})
-        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
+            std::array<int64_t, 4> per = {0, 1, 2, 3},
+            bool v = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -1997,6 +1999,7 @@ struct test_mul_mat : public test_case {
         const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
         if (npermuted > 0) {
             GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!v); // not handled
             GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
             GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
 
@@ -2020,7 +2023,13 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-            a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
+
+            if (v) {
+                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
+                a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
+            } else {
+                a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
+            }
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
             if (!ggml_is_quantized(type_a)) {
                 if (bs[1] == 1 && nr[1] == 1) {
@@ -4176,6 +4185,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));
 
+    for (auto bs : {1,2,4,8}) {
+        for (auto nr : {1,4}) {
+            for (uint32_t m = 0; m < 2; ++m) {
+                for (uint32_t k = 0; k < 2; ++k) {
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
+            }
+        }
+    }
+
     // sycl backend will limit task global_range < MAX_INT
     // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
     // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)
@@ -4444,6 +4464,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, true));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {