git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Hide latency of bias and gate-loading (llama/16847)
author Oliver Simons <redacted>
Thu, 30 Oct 2025 03:34:15 +0000 (04:34 +0100)
committer Georgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
This is realised by loading the bias and gate values into registers before
computation of the dot-product, effectively batching their loads together
with said dot-product. As many threads are alive here, the warp scheduler
has enough threads available to effectively hide the cost of additionally
loading those two floats.

src/ggml-cuda/mmvq.cu

index be04a85cc55154cb10963c6069b4b9eafd6db5c8..07645ad9e71d46e7092f140c83df53cb33aa4ff8 100644 (file)
@@ -190,12 +190,28 @@ static __global__ void mul_mat_vec_q(
 
     const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
+    float x_biases[ncols_dst][rows_per_cuda_block]    = { { 0.0f } };
+    float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
     if constexpr (has_fusion) {
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            // 1. Hide latency by prefetching bias and gate here
+            // 2. load only on threads that won't die after partial sum calculation
+            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+                for (int j = 0; j < ncols_dst; ++j) {
+                    x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
+                }
+            }
         }
         if (use_gate_bias) {
             gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+                for (int j = 0; j < ncols_dst; ++j) {
+                    gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
+                }
+            }
         }
     }
 
@@ -283,12 +299,12 @@ static __global__ void mul_mat_vec_q(
             float result = tmp[j][threadIdx.x];
             if constexpr (has_fusion) {
                 if (use_bias) {
-                    result += x_bias[j*stride_col_dst + threadIdx.x];
+                    result += x_biases[j][threadIdx.x];
                 }
                 if (use_gate) {
                     float gate_value = tmp_gate[j][threadIdx.x];
                     if (use_gate_bias) {
-                        gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
+                        gate_value += gate_biases[j][threadIdx.x];
                     }
                     switch (active_glu) {
                         case GGML_GLU_OP_SWIGLU: