]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: faster softmax via shared memory + fp16 math (llama/4742)
authorJohannes Gäßler <redacted>
Tue, 9 Jan 2024 07:58:55 +0000 (08:58 +0100)
committerGeorgi Gerganov <redacted>
Thu, 11 Jan 2024 19:50:01 +0000 (21:50 +0200)
ggml-cuda.cu

index 9b3df812b4c7da012d3dc77279f385f140bd5dcd..900f7ba4afac4ec43a44487e6138df7051e36dd6 100644 (file)
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
@@ -556,11 +557,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int     cc;                 // compute capability
+    size_t  smpb;               // max. shared memory per block
     bool    vmm;                // virtual memory support
     size_t  vmm_granularity;    // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -593,6 +595,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -601,6 +616,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
@@ -5385,75 +5413,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ half data_soft_max_f16[];
+    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
+    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
+    // in that case col_smem == col_data must be enforced to avoid race conditions
+
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        if (need_check && col_data + 0 >= ncols_data) {
+            val.x = -INFINITY;
+        } else {
+            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        }
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
+            vals[col_smem] = val;
+        }
+        max_val = __hmax2(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
+        }
+        __syncthreads();
+
+        max_val = __half2half2(buf_iw[lane_id]);
+        max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
+    }
+
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
+
+        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
+
+        tmp += val;
+        vals[col_smem] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp.x + tmp.y;
+        }
+        __syncthreads();
+
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+
+        if (need_check && col_data + 0 >= ncols_data) {
+            return;
+        }
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
 
     float max_val = -INFINITY;
 
-    for (int col = tid; col < ncols; col += block_size) {
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
     }
 
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-            buf[warp_id] = max_val;
+            buf_iw[warp_id] = max_val;
         }
         __syncthreads();
 
-        max_val = buf[lane_id];
+        max_val = buf_iw[lane_id];
         max_val = warp_reduce_max(max_val);
     }
 
-    float tmp = 0.f;
+    float tmp = 0.0f; // partial sum
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col] = val;
     }
 
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = 0.f;
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-            buf[warp_id] = tmp;
+            buf_iw[warp_id] = tmp;
         }
         __syncthreads();
 
-        tmp = buf[lane_id];
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.0f / tmp;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
@@ -6752,12 +6938,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
 }
 
 static void im2col_f32_f16_cuda(const float* x, half* dst,
@@ -7072,6 +7336,7 @@ void ggml_init_cublas() {
 #else
             g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            g_device_caps[id].smpb = prop.sharedMemPerBlock;
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -8087,7 +8352,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
 
-    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }
 
     (void) dst;
 }