return true;
}
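+    // SSM_CONV -> SILU: the conv kernel applies the activation right before the
+    // store, skipping a global-memory round trip for the intermediate tensor.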
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
+ const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
+ const ggml_tensor * silu = cgraph->nodes[node_idx+1];
+
+ if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ return true;
+ }
+
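+    // UNARY (SILU/SIGMOID/SOFTPLUS) -> MUL: one gated kernel computing op(x) * gate.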
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
+        && unary_ops.size() == 1
+        && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID
+            || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
+ const ggml_tensor * unary = cgraph->nodes[node_idx];
+ const ggml_tensor * mul = cgraph->nodes[node_idx+1];
+
+ if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
+ return false;
+ }
+
+ if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ if (unary->type != mul->type) {
+ return false;
+ }
+
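+        // the mul operand that is not the unary result acts as the gate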
+        const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
+        if (other->type != unary->type) {
+            return false;
+        }
+        // the fused kernel requires contiguous rows and writes the mul output flat
+        if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) ||
+            !ggml_are_same_shape(other, unary) || !ggml_is_contiguous(mul)) {
+            return false;
+        }
+
+ return true;
+ }
+
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
const ggml_tensor *scale = cgraph->nodes[node_idx];
continue;
}
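+            // fused two-node paths: launch a single kernel, then skip the second node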
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
+ i++;
+ continue;
+ }
+
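+            // probe each supported gate activation; the probe also verifies the
+            // node's actual unary op, so at most one of these matches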
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
+ ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
+ i++;
+ continue;
+ }
+
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
i += 2;
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
#include "ssm-conv.cuh"
+#include "unary.cuh"
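+// apply_silu is resolved at compile time, so the unfused instantiation is unchanged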
-template <size_t split_d_inner, size_t d_conv>
+template <bool apply_silu, size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
- y_block[i * stride_y + tid] = sumf;
+ y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
}
}
-template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
+template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);
- float x[d_conv] = { 0.0f };
- float w[d_conv] = { 0.0f };
+ const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);
+    constexpr int n_cols = d_conv - 1 + split_n_t;
+
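+    // Stage the input tile in shared memory: split_d_inner rows of n_cols
+    // (= d_conv - 1 + split_n_t) floats, so every output column can slide its
+    // d_conv-wide window over the tile instead of re-reading global memory.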
+ extern __shared__ float smem[];
+    constexpr int total_elems = split_d_inner * n_cols;
+    int row = tid / n_cols;
+    int col = tid % n_cols;
#pragma unroll
- for (size_t j = 0; j < d_conv; j++) {
- w[j] = w_block[tid * stride_w + j];
+    for (int idx = tid; idx < total_elems; idx += split_d_inner) {
+        // guard the tail z-block: only d_conv - 1 + local_n_t columns are valid
+        // there, and the compute loop below never touches the unloaded remainder
+        if (row < (int)split_d_inner && col < d_conv - 1 + local_n_t) {
+            smem[row * n_cols + col] = x_block[row * stride_x + col];
+        }
+
+        // advance the flattened index by one block-sized stride
+        col += split_d_inner;
+        row += col / n_cols;
+        col = col % n_cols;
}
+ __syncthreads();
+ // Load weights into registers (done once, small)
+ float w[d_conv] = { 0.0f };
#pragma unroll
- for (int64_t i = 0; i < split_n_t; i++) {
- if (bidz * split_n_t + i < n_t) {
- float sumf = 0.0f;
-
- if (i == 0) {
- for (size_t j = 0; j < d_conv; j++) {
- x[j] = x_block[tid * stride_x + j];
- }
- } else {
- x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
- }
+ for (size_t j = 0; j < d_conv; j++) {
+ w[j] = w_block[tid * stride_w + j];
+ }
+ // Compute from shared memory
+ for (int64_t i = 0; i < local_n_t; i++) {
+ float sumf = 0.0f;
#pragma unroll
- for (size_t j = 0; j < d_conv; j++) {
- sumf += x[(i + j) % d_conv] * w[j];
- }
- y_block[i * stride_y + tid] = sumf;
+ for (size_t j = 0; j < d_conv; j++) {
+ sumf += smem[tid * n_cols + i + j] * w[j];
}
+ y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
}
}
+template <bool apply_silu>
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
constexpr int kNC = decltype(NC)::value;
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
- ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+ ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
- ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
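+        // dynamic shared memory for the staged input tile: threads rows x (kNC - 1 + split_n_t) columns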
+ const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
+ ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
}
};
}
}
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+ const bool fuse_silu = silu_dst != nullptr;
+
+    // When fusing, write to silu_dst, since that is the tensor downstream nodes reference.
+ const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
const int64_t nc = src1->ne[0]; // d_conv
const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = dst->ne[1]; // tokens per sequence
- const int64_t n_s = dst->ne[2]; // number of sequences in the batch
+ const int64_t n_t = out->ne[1]; // tokens per sequence
+ const int64_t n_s = out->ne[2]; // number of sequences in the batch
- GGML_ASSERT(dst->ne[0] == nr);
+ GGML_ASSERT(out->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
- float * dst_d = (float *) dst->data;
+ float * dst_d = (float *) out->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
- dst->nb[2], nc, nr, n_t, n_s, stream);
+ GGML_ASSERT(out->type == GGML_TYPE_F32);
+ if (fuse_silu) {
+ ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+ out->nb[2], nc, nr, n_t, n_s, stream);
+ } else {
+ ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+ out->nb[2], nc, nr, n_t, n_s, stream);
+ }
}
leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
}
}
+
+/* fused unary + mul */
+
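+// Applies `op` to one operand and multiplies by the other in a single pass,
+// reusing the gated-unary launcher (unary_gated_cuda) from the GLU path.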
+template <float (*op)(float)>
+static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+ // unary_node: UNARY op applied to unary_node->src[0]
+ // mul_node: MUL(a, b) where one of a/b is unary_node
+ // Output goes to mul_node->data
+
+ const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op
+ const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0];
+
+ GGML_ASSERT(ggml_is_contiguous_1(unary_src));
+ GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));
+ GGML_ASSERT(ggml_is_contiguous_1(other_src));
+ GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));
+ GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));
+
+ GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
+ GGML_ASSERT(unary_src->type == other_src->type);
+    GGML_ASSERT(unary_src->type == mul_node->type);
+    GGML_ASSERT(ggml_is_contiguous(mul_node)); // output is written with flat indexing
+
+ cudaStream_t stream = ctx.stream();
+
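+    // flat output size, row width, and per-row source strides for the kernel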
+ const int64_t k = ggml_nelements(mul_node);
+ const int64_t nc = unary_src->ne[0];
+ const int64_t unary_stride = unary_src->nb[1];
+ const int64_t other_stride = other_src->nb[1];
+
+ if (unary_src->type == GGML_TYPE_F16) {
+ unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data,
+ (half *) mul_node->data, k, nc,
+ unary_stride / sizeof(half), other_stride / sizeof(half), stream);
+ } else {
+ unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data,
+ (float *) mul_node->data, k, nc,
+ unary_stride / sizeof(float), other_stride / sizeof(float), stream);
+ }
+}
+
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+ switch (ggml_get_unary_op(unary_node)) {
+ case GGML_UNARY_OP_SILU:
+ ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node);
+ break;
+ case GGML_UNARY_OP_SIGMOID:
+ ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node);
+ break;
+ case GGML_UNARY_OP_SOFTPLUS:
+ ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node);
+ break;
+ default:
+ GGML_ABORT("Unsupported unary op for fused unary+mul");
+ }
+}