return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+ char base[256];
+ char name[256];
+
+ const char * suffix = "";
+ if (op->src[1]->ne[0] % 4 == 0) {
+ suffix = "_4";
+ }
+
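+ // e.g. f32/f32 inputs with ne10 % 4 == 0 and ssm_conv_bs == 256 give
+ // base = "kernel_ssm_conv_f32_f32_batched_4" and name = "kernel_ssm_conv_f32_f32_batched_4_ssm_conv_bs=256"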
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
+ snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+ ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
- res.smem = 32*sizeof(float)*nsg;
+ // Shared memory layout:
+ // - sgptg * NW floats for partial sums (nsg * 32)
+ // - sgptg floats for shared_x_dt (nsg)
+ // - sgptg floats for shared_dA (nsg)
+ // Total: nsg * (32 + 2) floats
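+ // e.g. nsg = 4 -> (32 + 2)*4 = 136 floats = 544 bytes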
+ res.smem = (32 + 2)*sizeof(float)*nsg;
return res;
}
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
#define FC_MUL_MV 600
#define FC_MUL_MM 700
#define FC_ROPE 800
+#define FC_SSM_CONV 900
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
/*.nb2 =*/ nb2,
};
- auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
+ const bool use_batched = (ne1 > 1);
+
+ if (use_batched) {
+ // Pick a power-of-two threadgroup size based on ne1, capped at 256 (ne1 <= 4 uses a batch of 2)
+ int BATCH_SIZE;
+ if (ne1 > 128) BATCH_SIZE = 256;
+ else if (ne1 > 64 ) BATCH_SIZE = 128;
+ else if (ne1 > 32 ) BATCH_SIZE = 64;
+ else if (ne1 > 16 ) BATCH_SIZE = 32;
+ else if (ne1 > 8 ) BATCH_SIZE = 16;
+ else if (ne1 > 4 ) BATCH_SIZE = 8;
+ else BATCH_SIZE = 2;
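+ // e.g. ne1 = 512 -> BATCH_SIZE = 256 (2 token batches); ne1 = 3 -> BATCH_SIZE = 2 (2 token batches)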
+
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+ // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
+ // Each threadgroup has BATCH_SIZE threads, each handling one token
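+ // e.g. ne01 = 3328 rows, ne1 = 512 tokens, BATCH_SIZE = 256 -> grid (3328, 2, ne02), 256 threads per threadgroup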
+ const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+ }
return 1;
}
x[0] = sumf;
}
+constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
+
+// Batched version: each threadgroup processes multiple tokens for better efficiency
+// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
+kernel void kernel_ssm_conv_f32_f32_batched(
+ constant ggml_metal_kargs_ssm_conv & args,
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ // tgpig.x = row index (ir)
+ // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+ // tgpig.z = sequence index (i3)
+ // tpitg.x = thread within batch (0..BATCH_SIZE-1)
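+ // e.g. BATCH_SIZE = 32, tgpig.y = 2, tpitg.x = 5 -> i2 = 2*32 + 5 = 69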
+ const short BATCH_SIZE = FC_ssm_conv_bs;
+
+ const int64_t ir = tgpig.x;
+ const int64_t i2_base = tgpig.y * BATCH_SIZE;
+ const int64_t i3 = tgpig.z;
+ const int64_t i2_off = tpitg.x;
+ const int64_t i2 = i2_base + i2_off;
+
+ const int64_t nc = args.ne10; // conv kernel size (typically 4)
+ const int64_t n_t = args.ne1; // number of tokens
+
+ // Bounds check for partial batches at the end
+ if (i2 >= n_t) {
+ return;
+ }
+
+ // Load conv weights (shared across all tokens for this row)
+ device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+
+ // Load source for this specific token
+ device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+ // Output location for this token
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+ float sumf = 0.0f;
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
+ sumf += s[i0] * c[i0];
+ }
+
+ x[0] = sumf;
+}
+
+kernel void kernel_ssm_conv_f32_f32_batched_4(
+ constant ggml_metal_kargs_ssm_conv & args,
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ // tgpig.x = row index (ir)
+ // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
+ // tgpig.z = sequence index (i3)
+ // tpitg.x = thread within batch (0..BATCH_SIZE-1)
+ const short BATCH_SIZE = FC_ssm_conv_bs;
+
+ const int64_t ir = tgpig.x;
+ const int64_t i2_base = tgpig.y * BATCH_SIZE;
+ const int64_t i3 = tgpig.z;
+ const int64_t i2_off = tpitg.x;
+ const int64_t i2 = i2_base + i2_off;
+
+ const int64_t nc = args.ne10; // conv kernel size (typically 4)
+ const int64_t n_t = args.ne1; // number of tokens
+
+ // Bounds check for partial batches at the end
+ if (i2 >= n_t) {
+ return;
+ }
+
+ // Load conv weights (shared across all tokens for this row)
+ device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+
+ // Load source for this specific token
+ device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+
+ // Output location for this token
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
+
+ float sumf = 0.0f;
+ for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+ sumf += dot(s[i0], c[i0]);
+ }
+
+ x[0] = sumf;
+}
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// Optimized version: one thread per token loads x/dt and shares the derived values through threadgroup memory, avoiding redundant per-thread loads
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
- shared[tpitg.x] = 0.0f;
+ // Shared memory layout:
+ // [0..sgptg*NW-1]: partial sums for reduction (existing)
+ // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
+ // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dtsp (softplus of dt) for each token in batch; dA = exp(dtsp * A0) is applied per thread
+ threadgroup float * shared_sums = shared;
+ threadgroup float * shared_x_dt = shared + sgptg * NW;
+ threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
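+ // e.g. sgptg = 4, NW = 32: shared_sums = shared[0..127], shared_x_dt = shared[128..131], shared_dA = shared[132..135]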
+
+ shared_sums[tpitg.x] = 0.0f;
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
- const float dt0 = dt[0];
+ // Pre-compute x_dt and dtsp for this batch of tokens
+ // Only the first sgptg threads do the loads and the softplus
+ if (i0 < sgptg && i2 + i0 < n_t) {
+ // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
+ device const float * x_t = x + i0 * args.ns12;
+ device const float * dt_t = dt + i0 * args.ns21;
+
+ const float dt0 = dt_t[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
- const float x_dt = x[0] * dtsp;
- const float dA = exp(dtsp * A0);
+ shared_x_dt[i0] = x_t[0] * dtsp;
+ shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
+ }
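+ // Net effect: the x[0]/dt[0] loads and the softplus happen once per token (by one thread) instead of once per thread;
+ // exp(dtsp * A0) stays per-thread since A0 varies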
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+ const float x_dt = shared_x_dt[t];
+ const float dA = exp(shared_dA[t] * A0);
s = (s0 * dA) + (B[i0] * x_dt);
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
- shared[t*NW + sgitg] = sumf;
+ shared_sums[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
- x += args.ns12;
- dt += args.ns21;
B += args.ns42;
C += args.ns52;
}
+ // Advance pointers for next batch
+ x += sgptg * args.ns12;
+ dt += sgptg * args.ns21;
+
threadgroup_barrier(mem_flags::mem_threadgroup);
- const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+ const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
}
+ // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
+ test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
+ test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
+
return test_cases;
}