]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CANN: Add the basic supports of Flash Attention kernel (llama/13627)
authorBizhao Shi <redacted>
Mon, 26 May 2025 02:20:18 +0000 (10:20 +0800)
committerGeorgi Gerganov <redacted>
Tue, 27 May 2025 15:03:00 +0000 (18:03 +0300)
* cann: add the basic FA support

* cann: update the readme

* cann: update the FlashAttention with PSEShift

* cann: update the input parameters in FA

* cann: update the alibi with max_bias

* cann: add the constrints of softcap

* cann: update the docs CANN.md

* cann: update the docs CANN.md

* cann: fix typo of CANN.md

* cann: add some comments and update the CANN.md

* cann: update the CANN.md

* cann: update the inner precise for fusedInferAttention

* cann: update the constraints of flash_attn_ext on ggml-cann.cpp

* cann: clean the whitespace

* cann: clean the whitespace

* cann: add a new endline

ggml/src/ggml-cann/CMakeLists.txt [changed mode: 0644->0755]
ggml/src/ggml-cann/Doxyfile [changed mode: 0644->0755]
ggml/src/ggml-cann/acl_tensor.cpp [changed mode: 0644->0755]
ggml/src/ggml-cann/acl_tensor.h [changed mode: 0644->0755]
ggml/src/ggml-cann/aclnn_ops.cpp [changed mode: 0644->0755]
ggml/src/ggml-cann/aclnn_ops.h [changed mode: 0644->0755]
ggml/src/ggml-cann/common.h [changed mode: 0644->0755]
ggml/src/ggml-cann/ggml-cann.cpp [changed mode: 0644->0755]

old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
index f5462c5..f311864
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
index 9c67664..437ece2
@@ -66,6 +66,7 @@
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
 #include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>
 
 #include <cmath>
 #include <vector>
 
 #include "ggml-impl.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
 #include "../ggml-common.h"
 
+
 void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                  aclTensor ** acl_src1, aclTensor ** acl_dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2861,3 +2864,330 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             break;
     }
 }
+
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+    ggml_tensor* src0 = dst->src[0]; // q, fp32
+    ggml_tensor* src1 = dst->src[1]; // k, fp16
+    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+    float maxBias = 0.0f;
+    float scaleValue = 1.0f;
+    float logitSoftcap = 0.0f;
+    memcpy(&scaleValue,    (float*)dst->op_params + 0, sizeof(float));
+    memcpy(&maxBias,       (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&logitSoftcap,  (float*)dst->op_params + 2, sizeof(float));
+
+    if(logitSoftcap == 0.0f){
+        size_t faElemSize = sizeof(uint16_t);
+        auto   faDataType = ACL_FLOAT16; //ACL_BF16;
+
+        aclTensor* acl_src0_f16_tensor = nullptr;
+        aclTensor* acl_src1_f16_tensor = nullptr;
+        aclTensor* acl_src2_f16_tensor = nullptr;
+        aclTensor* acl_dst_f16_tensor  = nullptr;
+
+        // Step 1: cast the src0 (Query) to fp16 if needed
+        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+        void* src0_f16_buffer = nullptr;
+
+        if(ggml_cann_type_mapping(src0->type) != faDataType){
+            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+            src0_f16_buffer = src0_f16_allocator.alloc(
+                                    ggml_nelements(src0) * faElemSize);
+
+            int64_t* src0_f16_ne = src0->ne;
+            size_t   src0_f16_nb[GGML_MAX_DIMS];
+            src0_f16_nb[0] = sizeof(uint16_t);
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+            }
+
+            acl_src0_f16_tensor = ggml_cann_create_tensor(
+                src0_f16_buffer, faDataType, faElemSize,
+                src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+            );
+            aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
+            ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+        }else{
+            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        }
+
+        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+        //         and the direct output from FusedInferAttention
+
+        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
+        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        void* out_f16_buffer = out_f16_allocator.alloc(
+                                    ggml_nelements(dst) * faElemSize);
+
+        int64_t* out_f16_ne = src0->ne;
+        size_t out_f16_nb[GGML_MAX_DIMS];
+        out_f16_nb[0] = faElemSize;
+        for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+        }
+
+        acl_dst_f16_tensor = ggml_cann_create_tensor(
+            out_f16_buffer, faDataType, faElemSize,
+            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+        );
+
+        // Step 3: create the PSEShift tensor if needed
+        //         this tensor is considered as mask (f16) in the llama.cpp
+
+        aclTensor* bcast_pse_tensor = nullptr;
+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
+        void* bcast_pse_buffer = nullptr;
+
+        if(src3 != nullptr){
+            bcast_pse_buffer = bcast_pse_allocator.alloc(
+                            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+
+            if(src0->ne[1] > 1){
+                // Case 1: broadcast pse for prefill stage with multiple head
+                aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src3->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+            }else{
+                // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
+                int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
+                size_t* trunc_pse_nb = src3->nb;
+
+                aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+                    src3->data, ACL_FLOAT16, sizeof(uint16_t),
+                    trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
+
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src0->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+            }
+
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
+            if(maxBias != 0.0f){
+                // alibi
+                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+                const int64_t n_head = src0->ne[2];
+                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
+                // init arange
+                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                                    ne2_ne3 * faElemSize);
+                void* tmp_arange_buffer = arange_allocator.get();
+
+                // arange1: [1, ..., n_heads_log2_floor+1)
+                float start = 1;
+                float stop = n_heads_log2_floor + 1;
+                float step = 1;
+                int64_t n_elements_arange = n_heads_log2_floor;
+
+                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+                size_t tmp_arange1_nb[] = {faElemSize};
+                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_arange1_ne, tmp_arange1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+                aclTensor* tmp_arange2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+                    start = 1;
+                    stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+                    step = 2;
+                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_arange2_nb[] = {faElemSize};
+
+                    aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_arange_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                                n_elements_arange);
+                }
+
+                // init mk_base
+                ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                                    ne2_ne3 * faElemSize);
+                void* tmp_mk_base_buffer = mk_base_allocator.get();
+                int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+                size_t tmp_mk_base1_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base1_ne, tmp_mk_base1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+                aclTensor* tmp_mk_base2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_mk_base2_nb[] = {faElemSize};
+                    aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_mk_base_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+                }
+
+                // init mk
+                int64_t tmp_mk_base_ne[] = {ne2_ne3};
+                size_t tmp_mk_base_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+                // reshape mk
+                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+                size_t tmp_mk_nb[GGML_MAX_DIMS];
+                tmp_mk_nb[0] = faElemSize;
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+                }
+                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+                    ACL_FORMAT_ND);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+                    tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+                    tmp_arange_tensor, tmp_mk_tensor);
+            }
+        }
+
+        // Step 4: set the inputs for FusedInferAttention.
+        int kvTensorNum = 1;
+        aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+        aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+        aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+        int64_t numHeads = src0->ne[2]; // N
+        int64_t numKeyValueHeads = src1->ne[2];
+        // double  scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+        int64_t preTokens = 65535;
+        int64_t nextTokens = 65535;
+        char layout[5] = {'B', 'N', 'S', 'D', 0};
+        int64_t sparseMode = 0;
+        int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+        int64_t blockSize = 0;
+        int64_t antiquantMode = 0;
+        bool softmaxLseFlag = false;
+        int64_t keyAntiquantMode = 0;
+        int64_t valueAntiquantMode = 0;
+
+        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+            acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+            bcast_pse_tensor, nullptr, // pse, mask
+            nullptr, nullptr, // actSeqLen, actSeqLenkv
+            nullptr, nullptr, // deqScale1, quantScale1
+            nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
+            nullptr, nullptr, // antiquantScale, antiquantOffset
+            nullptr, // blockTable
+            nullptr, nullptr, // qPadSize, kvPadSize
+            nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
+            nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
+            nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
+            numHeads, scaleValue, // heads, scaleValue
+            preTokens, nextTokens, // preTokens, nextTokens
+            layout, // inputLayout
+            numKeyValueHeads, // numKVHeads
+            sparseMode, innerPrecise, // sparseMode, innerPrecise
+            blockSize, antiquantMode, // blockSize, antiquantMode
+            softmaxLseFlag, // softmaxLseFlag
+            keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
+            acl_dst_f16_tensor, // attentionOut
+            nullptr // softmaxLse
+        );
+
+        // Step 6: post-processing, permute and cast to f32
+
+        int64_t new_dim[] = {0, 2, 1, 3};
+        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        if(ggml_cann_type_mapping(dst->type) != faDataType){
+            ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+            perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+            void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+            int64_t* perm_out_f16_ne = dst->ne;
+            size_t  perm_out_f16_nb[GGML_MAX_DIMS];
+            perm_out_f16_nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+            }
+            aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+                perm_out_f16_buffer, faDataType, faElemSize,
+                perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+            aclnn_cast(ctx,
+                acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+            ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+        }else{
+            // only need to permute
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+        }
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                         acl_src1_f16_tensor,
+                                         acl_src2_f16_tensor,
+                                         acl_dst_f16_tensor,
+                                         acl_dst_tensor);
+        if(src3 != nullptr){
+            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
+    }else{
+        GGML_ABORT("Function is not implemented.");
+    }
+}
old mode 100644 (file)
new mode 100755 (executable)
index 15993cc..80ce80b
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief   Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ *          for computing scaled dot-product attention with hardware acceleration.
+ *          The result is stored in the destination tensor `dst`.
+ *
+ *          This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
index 605b6a7..c0ea260
@@ -36,6 +36,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-cann/aclnn_ops.h"
 #include "ggml-cann/common.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // derived from [ggml-cuda.cu]
+            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
+                return false;
+            }
+            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            float logitSoftcap = 0.0f;
+            memcpy(&logitSoftcap,  (float*)op->op_params + 2, sizeof(float));
+            if(logitSoftcap != 0.0f) {
+                return false;
+            }
+            return true;
+        }
         default:
             return false;
     }