]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use registers instead of smem in topk-moe (#16647)
authorAman Gupta <redacted>
Sat, 18 Oct 2025 09:52:53 +0000 (17:52 +0800)
committerGitHub <redacted>
Sat, 18 Oct 2025 09:52:53 +0000 (11:52 +0200)
Uses the technique used in the vulkan PR #16641. Neat trick!

ggml/src/ggml-cuda/topk-moe.cu

index afe4aee2403b24a23d6df28c99dc4c897e4fcda8..c588da2bb9e9325b1472359bc8ec9afafe96311d 100644 (file)
@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt_sum = 0.f;
 
-    extern __shared__ float data_topk_shared[];
-    float *                 wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
+    float output_weights[experts_per_thread];
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             }
         }
 
+        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
-            wt_shared_ptr[k] = max_val;
-            ids[k]           = max_expert;
+            ids[k] = max_expert;
             if constexpr (with_norm) {
                 wt_sum += max_val;
             }
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            output_weights[i] *= inv_sum;
         }
     }
 
-    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-        weights[i] = wt_shared_ptr[i];
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = i * WARP_SIZE + threadIdx.x;
+        if (idx < n_expert_used) {
+            weights[idx] = output_weights[i];
+        }
     }
 }
 
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
-    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
-
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
             topk_moe_cuda<2, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
             topk_moe_cuda<4, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
             topk_moe_cuda<8, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
             topk_moe_cuda<16, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
             topk_moe_cuda<32, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
             topk_moe_cuda<64, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
             topk_moe_cuda<128, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
             topk_moe_cuda<256, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
             topk_moe_cuda<512, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");