]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (#14435)
authorGeorgi Gerganov <redacted>
Fri, 27 Jun 2025 18:50:57 +0000 (21:50 +0300)
committerGeorgi Gerganov <redacted>
Wed, 2 Jul 2025 12:48:33 +0000 (15:48 +0300)
ggml-ci

ggml/include/ggml.h
ggml/src/ggml-cann/ggml-cann.cpp
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp

index ec5478db8c8df04537baf04e35f8df5526238d35..a92495dbb0070330ddec35300ceca0ea161a886f 100644 (file)
@@ -1510,8 +1510,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // a    [ne0, ne01, ne02, ne03]
+    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
+    //
+    // broadcast:
+    //   ne02 % ne12 == 0
+    //   ne03 % ne13 == 0
+    //
     // fused soft_max(a*scale + mask*(ALiBi slope))
-    // mask is optional
     // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
@@ -1974,11 +1980,16 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,    1]
-    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
-    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    //
+    // broadcast:
+    //   n_head % n_head_kv == 0
+    //   ne3    % ne32      == 0
+    //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
index d1a0ad374d691dc033d6958a56561d49550aac0d..8a3d2026ed898c459b78d4d9afb6852a7ef4203a 100755 (executable)
@@ -2187,7 +2187,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
@@ -2205,6 +2204,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SOFT_MAX:
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_FLASH_ATTN_EXT:{
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2227,6 +2230,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
index dd83efde7141a4fab1a23c809bec7e108f2eb658..2ae0721e9d171863529e2f1cf45e07342125945e 100644 (file)
@@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
@@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type    const k_vec_dot_type      = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
index 086f9a56c4acaf392aabf1835f212a6d355430c8..186492be3402787d49eba3e0044c81856071555e 100644 (file)
@@ -3327,8 +3327,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
             return true;
+        case GGML_OP_SOFT_MAX:
+            // TODO: support batching
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
@@ -3375,6 +3382,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
index 7a9aab31684e18522859c70e03cbfa7c4412a150..8c2a971927414de104d797e41eb35a927f953bd8 100644 (file)
@@ -229,7 +229,9 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne32;
     uint64_t nb31;
+    uint64_t nb32;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
@@ -461,9 +463,21 @@ typedef struct {
 } ggml_metal_kargs_sum_rows;
 
 typedef struct {
-    int64_t  ne00;
-    int64_t  ne01;
-    int64_t  ne02;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    scale;
     float    max_bias;
     float    m0;
index 12a366957891c63e43a8ebd9d2191c33b6ecd828..1a0968ed4d23173fcb635653bd41c910ab2d3af8 100644 (file)
@@ -1725,7 +1725,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -2644,10 +2644,7 @@ static bool ggml_metal_encode_node(
                 memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale));
                 memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
 
-                const int64_t nrows_x = ggml_nrows(src0);
-                const int64_t nrows_y = src0->ne[1];
-
-                const uint32_t n_head      = nrows_x/nrows_y;
+                const uint32_t n_head      = src0->ne[2];
                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
@@ -2707,6 +2704,18 @@ static bool ggml_metal_encode_node(
                     /*.ne00        =*/ ne00,
                     /*.ne01        =*/ ne01,
                     /*.ne02        =*/ ne02,
+                    /*.nb01        =*/ nb01,
+                    /*.nb02        =*/ nb02,
+                    /*.nb03        =*/ nb03,
+                    /*.ne11        =*/ ne11,
+                    /*.ne12        =*/ ne12,
+                    /*.ne13        =*/ ne13,
+                    /*.nb11        =*/ nb11,
+                    /*.nb12        =*/ nb12,
+                    /*.nb13        =*/ nb13,
+                    /*.nb1         =*/ nb1,
+                    /*.nb2         =*/ nb2,
+                    /*.nb3         =*/ nb3,
                     /*.scale       =*/ scale,
                     /*.max_bias    =*/ max_bias,
                     /*.m0          =*/ m0,
@@ -2726,7 +2735,7 @@ static bool ggml_metal_encode_node(
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
@@ -4979,7 +4988,9 @@ static bool ggml_metal_encode_node(
                     /*.nb21          =*/ nb21,
                     /*.nb22          =*/ nb22,
                     /*.nb23          =*/ nb23,
+                    /*.ne32          =*/ ne32,
                     /*.nb31          =*/ nb31,
+                    /*.nb32          =*/ nb32,
                     /*.ne1           =*/ ne1,
                     /*.ne2           =*/ ne2,
                     /*.scale         =*/ scale,
index dac45c7a99b52b326963efc549f23e425ae3a776..dfff669726b94ea9a4ec7d784a250c200f27f22a 100644 (file)
@@ -1320,24 +1320,28 @@ kernel void kernel_soft_max(
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
-    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
-    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
+        uint3  tptg[[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x;
+
+    const int32_t i13 = i03%args.ne13;
+    const int32_t i12 = i02%args.ne12;
+    const int32_t i11 = i01;
 
-    device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
-    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*args.ne00 : nullptr;
-    device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+    device const float * psrc0 =                (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+    device const     T * pmask = src1 != src0 ? (device const T *    ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device       float * pdst  =                (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
 
     // ALiBi
     if (args.max_bias > 0.0f) {
-        const int64_t h = i02;
+        const int32_t h = i02;
 
         const float base = h < args.n_head_log2 ? args.m0 : args.m1;
         const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1348,13 +1352,13 @@ kernel void kernel_soft_max(
     // parallel max
     float lmax = -INFINITY;
 
-    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
 
     // find the max value in the block
     float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = -INFINITY;
         }
@@ -1373,7 +1377,7 @@ kernel void kernel_soft_max(
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
@@ -1385,7 +1389,7 @@ kernel void kernel_soft_max(
 
     float sum = simd_sum(lsum);
 
-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
         }
@@ -1404,7 +1408,7 @@ kernel void kernel_soft_max(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         pdst[i00] *= inv_sum;
     }
 }
@@ -1416,23 +1420,27 @@ kernel void kernel_soft_max_4(
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
-    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
-    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
+        uint3  tptg[[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x;
+
+    const int32_t i13 = i03%args.ne13;
+    const int32_t i12 = i02%args.ne12;
+    const int32_t i11 = i01;
 
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
-    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*args.ne00/4 : nullptr;
-    device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+    device const float4 * psrc4 =                (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+    device const      T * pmask = src1 != src0 ? (device const T *     ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device       float4 * pdst4 =                (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
 
     if (args.max_bias > 0.0f) {
-        const int64_t h = i02;
+        const int32_t h = i02;
 
         const float base = h < args.n_head_log2 ? args.m0 : args.m1;
         const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1443,14 +1451,14 @@ kernel void kernel_soft_max_4(
     // parallel max
     float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
     float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = -INFINITY;
         }
@@ -1469,7 +1477,7 @@ kernel void kernel_soft_max_4(
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
@@ -1483,7 +1491,7 @@ kernel void kernel_soft_max_4(
 
     float sum = simd_sum(lsum);
 
-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
         }
@@ -1502,7 +1510,7 @@ kernel void kernel_soft_max_4(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         pdst4[i00] *= inv_sum;
     }
 }
@@ -3776,7 +3784,7 @@ kernel void kernel_flash_attn_ext(
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
 
                     const float m = pm[ic + tiisg];
 
@@ -4262,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec(
         const bool has_mask = mask != q;
 
         // pointer to the mask
-        device const half * pm = (device const half *) (mask + iq1*args.nb31);
+        device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
 
         float slope = 1.0f;
 
index ae5e062572e324a8c54eab4abe482742d0d6fc0c..1d41d7a4be8206182bca3c3bbb3f86d48badf644 100644 (file)
@@ -4395,9 +4395,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
-        case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
-            return true;
+            // TODO: support batching
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+        case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
index 7c11890d9b44dfee95868998d727bd51efb54f3c..b8e25ba20980a5cbaadb387b1e74e10febbe5e39 100644 (file)
@@ -10248,6 +10248,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                     return false;
                 }
+                // TODO: support broadcast
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+                if (op->src[0]->ne[3] != 1) {
+                    return false;
+                }
                 // It's straightforward to support different K/V dequant, but would
                 // significantly increase the number of pipelines
                 if (op->src[1]->type != op->src[2]->type) {
@@ -10406,7 +10411,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
+            return true;
         case GGML_OP_SOFT_MAX:
+            // TODO: support batching
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
index 4227fb101f09a685a0800a69c22f732284f1a887..f1245c50c9299013d08d5213a470b398b470ef9f 100644 (file)
@@ -3666,9 +3666,11 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+        GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
+        GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
@@ -4689,13 +4691,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
+    GGML_ASSERT(q->ne[3] == k->ne[3]);
+    GGML_ASSERT(q->ne[3] == v->ne[3]);
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == 1);
-        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+
+        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
     }
 
     if (max_bias > 0.0f) {
index 2b502e8b1ff8c6182aaa09fc1377f974dbb4c3ee..a4f94c2594038edd5863ca5b2750c76c0dc01702 100644 (file)
@@ -2685,11 +2685,12 @@ struct test_soft_max : public test_case {
     const std::array<int64_t, 4> ne;
     const bool mask;
     const ggml_type m_prec;
+    const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
+        return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -2702,18 +2703,19 @@ struct test_soft_max : public test_case {
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
             ggml_type m_prec = GGML_TYPE_F32,
+            std::array<int64_t, 2> nr23 = {1, 1},
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
+            mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]);
             ggml_set_name(mask, "mask");
         }
 
@@ -3544,7 +3546,7 @@ struct test_flash_attn_ext : public test_case {
     const int64_t hsk; // K head size
     const int64_t hsv; // V head size
     const int64_t nh; // num heads
-    const int64_t nr; // repeat in Q, tests for grouped-query attention
+    const std::array<int64_t, 2> nr23; // repeat in dim 2 and 3, tests for grouped-query attention
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
@@ -3558,7 +3560,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3569,13 +3571,13 @@ struct test_flash_attn_ext : public test_case {
         GGML_UNUSED(t);
         // Just counting matmul costs:
         // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
-        return 2 * nh*nr * nb * (hsk + hsv) * kv;
+        return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1];
     }
 
-    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                         ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
@@ -3594,18 +3596,18 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,    1);
+        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1]);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,    1);
+        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1]);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
             ggml_set_name(m, "m");
         }
 
@@ -4714,26 +4716,31 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (int64_t ne1 : {16, 1024}) {
                         if (mask) {
                             for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, scale, max_bias));
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+
+                                if (ne0 <= 32 && ne1 <= 32) {
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+                                }
                             }
                         } else {
                             /* The precision of mask here doesn't matter as boolean mask is false */
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                         }
                     }
                 }
             }
         }
     }
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 8.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -4833,20 +4840,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (float logit_softcap : {0.0f, 10.0f}) {
                         if (hsk != 128 && logit_softcap != 0.0f) continue;
                         for (int nh : { 4, }) {
-                            for (int nr : { 1, 4, 16 }) {
-                                if (nr == 16 && hsk != 128) continue;
-                                for (int kv : { 512, 1024, }) {
-                                    if (nr != 1 && kv != 512) continue;
-                                    for (int nb : { 1, 3, 32, 35, }) {
-                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                                test_cases.emplace_back(new test_flash_attn_ext(
-                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                                // run fewer test cases permuted
-                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                            for (int nr3 : { 1, 3, }) {
+                                if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+                                for (int nr2 : { 1, 4, 16 }) {
+                                    if (nr2 == 16 && hsk != 128) continue;
+                                    for (int kv : { 512, 1024, }) {
+                                        if (nr2 != 1 && kv != 512) continue;
+                                        for (int nb : { 1, 3, 32, 35, }) {
+                                            for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                                if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                     test_cases.emplace_back(new test_flash_attn_ext(
-                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                                hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                    // run fewer test cases permuted
+                                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                        test_cases.emplace_back(new test_flash_attn_ext(
+                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                    }
                                                 }
                                             }
                                         }
@@ -4890,13 +4900,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
@@ -4928,7 +4938,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {
-                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
             }
         }
     }