]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix CUDA softmax by subtracting max value before exp (#2665)
authorJiahao Li <redacted>
Tue, 22 Aug 2023 18:27:06 +0000 (02:27 +0800)
committerGitHub <redacted>
Tue, 22 Aug 2023 18:27:06 +0000 (20:27 +0200)
ggml-cuda.cu

index 8ab29bb2080249d5a182bd29db3083bd9b42aeef..4fe378c210030fcbbcd8834777a6ac05e9bcbbdb 100644 (file)
@@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-// values are also not normalized to the maximum value by subtracting it in the exponential function
-// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int block_size = blockDim.y;
     const int tid = threadIdx.y;
 
-    float tmp = 0.0;
+    float max_val = -INFINITY;
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        max_val = max(max_val, x[i]);
+    }
 
-        if (col >= ncols) {
-            break;
-        }
+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
+
+    float tmp = 0.f;
 
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        const float val = expf(x[i]);
+        const float val = expf(x[i] - max_val);
         tmp += val;
         dst[i] = val;
     }
@@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
+    const float inv_tmp = 1.f / tmp;
 
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        dst[i] /= tmp;
+        dst[i] *= inv_tmp;
     }
 }