]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: implement the SSM_CONV operator (llama/17737)
author0Marble <redacted>
Fri, 26 Dec 2025 01:12:04 +0000 (09:12 +0800)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
* CANN: implement SSM_CONV operator

Co-authored-by: Aleksei Lobanov, <redacted>
Co-authored-by: Sujin Kang, <redacted>
* CANN: remove custom error limit for SSM_CONV

* CANN: merge SSM_CONV tensor shape/strides into one line

---------

Co-authored-by: Sujin Kang, <redacted>
src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/aclnn_ops.h
src/ggml-cann/ggml-cann.cpp

index 3688abdd5896b0d0cdee020673aa16fe80fb1371..2180a06fd00123fcafe82ea5515cc428fa6ec7e9 100644 (file)
@@ -3702,3 +3702,106 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
             break;
     }
 }
+
+void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src0 = dst->src[0];  // conv_x
+    ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+
+    // This op is currently defined only for F32 in ggml_cpu
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // Shapes follow ggml_compute_forward_ssm_conv_f32
+    const int64_t nc  = src1->ne[0];   // d_conv
+    const int64_t ncs = src0->ne[0];   // d_conv - 1 + n_t
+    const int64_t nr  = src0->ne[1];   // d_inner
+    const int64_t n_s = src0->ne[2];   // n_seqs
+
+    const int64_t n_t = dst->ne[1];    // tokens per sequence
+
+    GGML_ASSERT(dst->ne[0] == nr);     // dst: {d_inner, n_t, n_s}
+    GGML_ASSERT(src1->ne[1] == nr);    // weight: {d_conv, d_inner}
+    GGML_ASSERT(ncs == nc - 1 + n_t);  // conv_x: {d_conv - 1 + n_t, d_inner, n_s}
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    // --- Build CANN tensors ---
+
+    // 1) Input: conv_x as NCL
+    //
+    // src0->ne = { ncs, nr, n_s, 1 }  // {L_in, C, N}
+    // Passing ACL_FORMAT_NCL here means:
+    //   reversed dims -> [N, C, L_in] = [n_s, nr, ncs]
+    acl_tensor_ptr acl_x = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
+
+    // 2) Weights: depthwise conv kernel, view src1 as {K, 1, C}
+    //
+    // src1 original:   ne = { nc, nr, 1, 1 }  // [K, C, 1, 1]
+    // we want a view:  ne_w = { nc, 1, nr }   // [K, 1, C]
+    // so that reversed dims -> [C, 1, K] which matches
+    //   [out_channels, in_channels/groups, kernel_size]
+    int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
+    // Layout: src1 data is [K, C] with
+    //   offset(k, c) = k*nb0 + c*nb1
+    // We want offset_w(k, 0, c) = k*nb0 + c*nb1,
+    // so we can reuse nb0 and nb1, and set nb2 = nb1.
+    size_t  w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
+
+    acl_tensor_ptr acl_w = ggml_cann_create_tensor(
+        src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
+
+    // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
+    //
+    // We need an NCL view of the same buffer:
+    //   desired NCL logical shape: { L_out = n_t, C = nr, N = n_s }
+    //
+    // Original CLN layout:
+    //   dst->ne = { nr, n_t, n_s }
+    //   dst->nb[0] = sizeof(float)
+    //   dst->nb[1] = nr * sizeof(float)
+    //   dst->nb[2] = nr * n_t * sizeof(float)
+    //
+    // We want offset_new(L, C, N) = offset_orig(C, L, N).
+    // Choose:
+    //   nb_y[0] = nr * sizeof(float);           // step in L
+    //   nb_y[1] = sizeof(float);                // step in C
+    //   nb_y[2] = nr * n_t * sizeof(float);     // step in N
+    int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
+    size_t  y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
+
+    acl_tensor_ptr acl_y = ggml_cann_create_tensor(
+        dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
+
+    // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
+    int64_t strideVal[1]   = { 1 };
+    int64_t paddingVal[1]  = { 0 };
+    int64_t dilationVal[1] = { 1 };
+
+    acl_int_array_ptr stride   = ggml_cann_create_int_array(strideVal, 1);
+    acl_int_array_ptr padding  = ggml_cann_create_int_array(paddingVal, 1);
+    acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
+
+    const bool    transposed   = false;
+    const int64_t groups       = nr;  // depthwise: one group per inner dim
+    int8_t        cubeMathType = 0;
+
+#ifdef ASCEND_310P
+    cubeMathType = 1;
+#endif
+
+    GGML_CANN_CALL_ACLNN_OP(ctx,
+                            Convolution,
+                            acl_x.get(),    // input:  N, C, L_in = ncs
+                            acl_w.get(),    // weight: [C, 1, K] with groups=nr
+                            nullptr,        // bias
+                            stride.get(),
+                            padding.get(),
+                            dilation.get(),
+                            transposed,
+                            padding.get(),   // output padding (unused for non-transposed)
+                            groups,
+                            acl_y.get(),
+                            cubeMathType);
+}
+
index 6ec46402894d355d2237920086c72fa434a6bcd4..a6ea016c542780c80f65e039e5dfb3e82f28b99e 100644 (file)
@@ -1033,6 +1033,8 @@ void ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTenso
                         ggml_backend_cann_context &                                                ctx,
                         ggml_tensor *                                                              dst);
 
+void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
 /**
  * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
  *
index e90759f98b3a6511dd9be0b331545b20098e0fee..ef23ec78da69340e9baef1cc9ece73626bf3ed7e 100644 (file)
@@ -1888,6 +1888,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
             break;
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
+        case GGML_OP_SSM_CONV:
+            ggml_cann_ssm_conv(ctx, dst);
             break;
         default:
             return false;
@@ -2471,6 +2473,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
                 }
                 return true;
             }
+        case GGML_OP_SSM_CONV:
+            return true;
         default:
             return false;
     }