From: compilade
Date: Thu, 28 Aug 2025 14:11:36 +0000 (-0400)
Subject: ggml : fix SSM_SCAN for n_groups > 1 (llama/15625)
X-Git-Tag: v0.9.1~134
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e4556fe51fcae5234ee953049d3bf70b6dfd6476;p=pkg%2Fggml%2Fsources%2Fggml

ggml : fix SSM_SCAN for n_groups > 1 (llama/15625)
---

diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp
index 93330b43..8c1f7948 100644
--- a/src/ggml-cpu/ops.cpp
+++ b/src/ggml-cpu/ops.cpp
@@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
-    // allows optimizing the modulo since n_group should be a power of 2
-    GGML_ASSERT((ng & -ng) == ng);
+    GGML_ASSERT(nh % ng == 0);
 
     // heads per thread
     const int dh = (nh + nth - 1)/nth;
@@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
                 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
                 const float dA = expf(dt_soft_plus * A[h]);
+                const int g = h / (nh / ng); // repeat_interleave
 
                 // dim
                 for (int i1 = 0; i1 < nr; ++i1) {
@@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                         // TODO: maybe unroll more?
                         for (int j = 0; j < 1; j++) {
                             GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                            GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
-                            GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                            GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+                            GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
 
                             t0 = GGML_F32_VEC_MUL(t0, adA);
                             t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -9090,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     for (int i = 0; i < np; i += GGML_F32_STEP) {
                         for (int j = 0; j < GGML_F32_ARR; j++) {
                             ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
-                            ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
-                            az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                            ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+                            az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
 
                             ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
                             ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -9113,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // d_state
                     for (int i0 = np; i0 < nc; ++i0) {
                         const int i = i0 + ii*nc;
-                        const int ig = i0 + (h & (ng - 1))*nc;
+                        const int ig = i0 + g*nc;
                         // state = prev_state * dA + dB * x
                         const float state = (s0[i] * dA) + (B[ig] * x_dt);
                         // y = rowwise_dotprod(state, C)
@@ -9130,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
                 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                const int g = h / (nh / ng); // repeat_interleave
 
                 // dim
                 for (int i1 = 0; i1 < nr; ++i1) {
@@ -9144,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // TODO: what happens when (d_state % svcntw()) != 0?
                     for (int64_t k = 0; k < nc; k += svcntw()) {
                         svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
-                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
                         svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
 
                         svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9165,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // d_state
                     for (int i0 = 0; i0 < nc; ++i0) {
                         const int i = i0 + ii*nc;
-                        const int ig = i0 + (h & (ng - 1))*nc;
+                        const int ig = i0 + g*nc;
                         // state = prev_state * dA + dB * x
                         const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                         // y = rowwise_dotprod(state, C)
diff --git a/src/ggml-cuda/ssm-scan.cu b/src/ggml-cuda/ssm-scan.cu
index dc9a7d58..6b424381 100644
--- a/src/ggml-cuda/ssm-scan.cu
+++ b/src/ggml-cuda/ssm-scan.cu
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
     const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
     const int seq_idx = blockIdx.y;
 
-    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
 
     const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
     const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal
index fa80d6e4..4fa16c4a 100644
--- a/src/ggml-metal/ggml-metal.metal
+++ b/src/ggml-metal/ggml-metal.metal
@@ -1983,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
 
     const int64_t i = i0 + i1*nc;
+    const int64_t g = ir / (nh / ng); // repeat_interleave
     float s0 = s0_buff[i];
     float s = s_buff[i];
 
     device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
     device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
     device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
-    device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+    device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
     device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -2098,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
 
     const int64_t i = i0 + i1*nc;
+    const int64_t g = ir / (nh / ng); // repeat_interleave
     float s0 = s0_buff[i];
     float s = s_buff[i];
 
     device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
     device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
     device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
-    device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+    device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
     device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
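
The head-to-group mapping behind this change, sketched below as a standalone C program (not part of the patch; nh = 8 and ng = 2 are made-up example values): the removed expression h & (ng - 1) only equals h % ng when ng is a power of two, and even then it spreads heads across groups round-robin, whereas the added expression h / (nh / ng) gives every block of nh/ng consecutive heads the same group, matching the repeat_interleave layout of the B and C tensors that the new comments point to.

// Standalone sketch with hypothetical example values, comparing the old and
// new head -> group index formulas touched by this commit.
#include <stdio.h>

int main(void) {
    const int nh = 8; // number of heads  (example value)
    const int ng = 2; // number of groups (example value; nh % ng == 0, as the new assert requires)

    for (int h = 0; h < nh; ++h) {
        const int g_old = h & (ng - 1);  // removed formula: round-robin, assumes ng is a power of 2
        const int g_new = h / (nh / ng); // added formula:   blocks of nh/ng heads share one group
        printf("h=%d  old g=%d  new g=%d\n", h, g_old, g_new);
    }
    return 0;
}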