From: Aman Gupta
Date: Mon, 15 Sep 2025 09:35:11 +0000 (+0800)
Subject: CUDA: some micro-optimizations in mmf.cuh for mul_mat_id (llama/15926)
X-Git-Tag: v0.9.1~27
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=455c66a7323308fa0cf3c4d008f0e1ac09b784fa;p=pkg%2Fggml%2Fsources%2Fggml

CUDA: some micro-optimizations in mmf.cuh for mul_mat_id (llama/15926)
---

diff --git a/src/ggml-cuda/mmf.cuh b/src/ggml-cuda/mmf.cuh
index bf724bc5..61e3bf30 100644
--- a/src/ggml-cuda/mmf.cuh
+++ b/src/ggml-cuda/mmf.cuh
@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
-        __shared__ int has_any;
-        if (threadIdx.y == 0) {
-            int local_has_any = 0;
-            for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
-                int slot = -1;
-                for (int k = 0; k < nchannels_dst; ++k) {
-                    const int idv = ids[j*stride_row_id + k*stride_col_id];
-                    if (idv == expert_idx) {
-                        slot = k;
-                        break;
-                    }
-                }
-                if (j < cols_per_block) {
-                    local_has_any |= (slot >= 0);
-                    slot_map[j] = slot;
+        int found = 0;
+
+        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
+            if (threadIdx.x == 0) {
+                slot_map[j] = -1;
+            }
+
+            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
+                int match = id_row[k*stride_col_id] == expert_idx;
+
+                if (match) {
+                    slot_map[j] = k;
+                    found = 1;
+                    break;
                 }
             }
-            has_any = warp_reduce_any(local_has_any);
         }
-        __syncthreads();
-        if (has_any == 0) {
+
+        if (!__syncthreads_or(found)) {
             return;
         }
     }
+
     for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
         tile_A A[ntA][warp_size / tile_A::J];
 #pragma unroll
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
             if constexpr (!has_ids) {
                 tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
             } else {
-                float val = 0.0f;
-                if (j < cols_per_block) {
-                    const int slot = slot_map[j];
-                    if (slot >= 0) {
-                        val = y[slot*stride_channel_y + j*stride_col_y + col];
-                    }
-                }
-                tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+                tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
             }
         }
     } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
             } else {
-                float2 tmp = make_float2(0.0f, 0.0f);
-                if (j < cols_per_block) {
-                    const int slot = slot_map[j];
-                    if (slot >= 0) {
-                        const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
-                        tmp = y2_slot[j*stride_col_y + col];
-                    }
-                }
+                float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
             }
         }
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
         mul_mat_f<T, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
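
Note: the core of the first hunk is replacing the old single-warp scan over ids
(shared has_any flag + warp_reduce_any() + __syncthreads()) with a search that
is parallelized across all warps and lanes, followed by one fused
barrier-and-vote via __syncthreads_or(). The following standalone sketch shows
the same pattern in isolation; it is not the ggml kernel. The names
(find_expert_slots, slot_map_out), the dense ids layout, and the added
__syncwarp() between the slot_map initialization and the scan are all
illustrative assumptions.

    // Sketch of the new expert-slot search: each warp takes one row of ids
    // per iteration, its lanes scan the channels in parallel, and the block
    // votes once with __syncthreads_or() on whether any column matched.
    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int cols_per_block = 8;   // ids rows handled per block
    constexpr int nwarps         = 4;   // blockDim.y
    constexpr int warp_size      = 32;  // blockDim.x

    static_assert(cols_per_block % nwarps == 0, "warps must cover whole rows");

    __global__ void find_expert_slots(const int32_t * __restrict__ ids, const int nchannels_dst,
                                      const int expert_idx, int * __restrict__ slot_map_out) {
        __shared__ int slot_map[cols_per_block];

        int found = 0;

        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
            const int j = j0 + threadIdx.y; // one ids row per warp

            if (threadIdx.x == 0) {
                slot_map[j] = -1;
            }
            __syncwarp(); // make the -1 visible to all lanes before the scan

            // Lanes stride over the channel dimension instead of a serial loop.
            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
                if (ids[j*nchannels_dst + k] == expert_idx) {
                    slot_map[j] = k; // at most one k matches per row in MoE routing
                    found = 1;
                    break;
                }
            }
        }

        // Fused barrier + block-wide OR: if no thread saw this expert,
        // the whole block exits before doing any work.
        if (!__syncthreads_or(found)) {
            return;
        }

        for (int j = threadIdx.y*warp_size + threadIdx.x; j < cols_per_block; j += nwarps*warp_size) {
            slot_map_out[j] = slot_map[j];
        }
    }

    int main() {
        const int nchannels_dst = 5;
        const int expert_idx    = 2;

        // Toy routing table: row j selects expert 2 in channel j % nchannels_dst.
        int32_t h_ids[cols_per_block*nchannels_dst];
        for (int j = 0; j < cols_per_block; ++j) {
            for (int k = 0; k < nchannels_dst; ++k) {
                h_ids[j*nchannels_dst + k] = (k == j % nchannels_dst) ? expert_idx : -1;
            }
        }

        int32_t * d_ids   = nullptr;
        int     * d_slots = nullptr;
        cudaMalloc(&d_ids,   sizeof(h_ids));
        cudaMalloc(&d_slots, cols_per_block*sizeof(int));
        cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);

        find_expert_slots<<<1, dim3(warp_size, nwarps, 1)>>>(d_ids, nchannels_dst, expert_idx, d_slots);

        int h_slots[cols_per_block];
        cudaMemcpy(h_slots, d_slots, sizeof(h_slots), cudaMemcpyDeviceToHost);
        for (int j = 0; j < cols_per_block; ++j) {
            printf("col %d -> slot %d\n", j, h_slots[j]); // expect j % nchannels_dst
        }

        cudaFree(d_ids);
        cudaFree(d_slots);
        return 0;
    }

Compared to the old scheme, every warp makes forward progress during the
search, and the early-exit decision costs a single barrier instead of a warp
reduction plus a separate __syncthreads().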
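
The half2/nv_bfloat162 hunk also folds the slot lookup into a single
predicated float2 load: indexing y with 2*(j*stride_col_y + col) and
reinterpreting the address as const float2 * lets the compiler emit one 64-bit
load instead of two 32-bit loads. A minimal sketch of that reinterpretation,
assuming the base pointer is 8-byte aligned (the load_pair helper is
hypothetical, for illustration only):

    #include <cuda_runtime.h>

    // Reads y[2*idx] and y[2*idx + 1] with one 64-bit load. Requires
    // &y[2*idx] to be 8-byte aligned, which holds whenever y itself is
    // 8-byte aligned, since the index is scaled by 2.
    __device__ float2 load_pair(const float * __restrict__ y, int idx) {
        return *(const float2 *) &y[2*idx];
    }

This mirrors the cast in the patch; in device code this reinterpreting cast is
the common idiom for vectorized loads, valid as long as the address is
naturally aligned for float2.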