]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : optimize argmax (llama/10441)
authorDiego Devesa <redacted>
Thu, 21 Nov 2024 17:18:50 +0000 (18:18 +0100)
committerGeorgi Gerganov <redacted>
Tue, 3 Dec 2024 19:05:37 +0000 (21:05 +0200)
* cuda : optimize argmax

* remove unused parameter

ggml-ci

* fixup : use full warps

ggml-ci

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <redacted>
* fix ub

* ggml : check ne00 <= INT32_MAX in argmax and argsort

---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/argmax.cu
src/ggml-cuda/common.cuh
src/ggml-cuda/quantize.cu
src/ggml.c

index aab04eca7a38522f5a3f4a4e01be64646169e698..5340eedc08916cdaace562b248ef54d900ce2021 100644 (file)
@@ -1,57 +1,69 @@
-#include "common.cuh"
+#include <algorithm>
+#include <cstdint>
+
 #include "argmax.cuh"
+#include "common.cuh"
 #include "sum.cuh"
 
-#include <cstdint>
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+    const int64_t row = blockIdx.x;
 
-static __global__ void argmax_f32(
-    const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
+    float maxval = -FLT_MAX;
+    int   argmax = -1;
+    const float * rowx = x + row * ncols;
 
-    int argmax_thread = 0;
-    const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
+    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
+        const float val = rowx[col];
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
 
 #pragma unroll
-    for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
-        const int64_t row = row0 + row1;
-
-        if (row >= nrows) {
-            break;
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
         }
+    }
 
-        float maxval = -FLT_MAX;
-        int   argmax = -1;
-
-        for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
-            const float val        = x[row*ncols + col];
-            const int   bigger     = val > maxval;
-            const int   not_bigger = bigger ^ 0x00000001;
-
-            maxval = maxval*not_bigger + val*bigger;
-            argmax = argmax*not_bigger + col*bigger;
+    const int n_warps = blockDim.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    if (n_warps > 1) {
+        constexpr int    max_warps = 1024 / WARP_SIZE;
+        __shared__ float shared_maxval[max_warps];
+        __shared__ int   shared_argmax[max_warps];
+        if (lane_id == 0) {
+            shared_maxval[warp_id] = maxval;
+            shared_argmax[warp_id] = argmax;
         }
 
+        __syncthreads();
+
+        if (warp_id == 0) {
+            if (lane_id < n_warps) {
+                maxval = shared_maxval[lane_id];
+                argmax = shared_argmax[lane_id];
+            }
 #pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            const float val        = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
-            const int   col        = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
-            const int   bigger     = val > maxval;
-            const int   not_bigger = bigger ^ 0x00000001;
-
-            maxval = maxval*not_bigger + val*bigger;
-            argmax = argmax*not_bigger + col*bigger;
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+                if (val > maxval) {
+                    maxval = val;
+                    argmax = col;
+                }
+            }
         }
-
-        const int store = row1 == threadIdx.x;
-        argmax_thread += store*argmax;
     }
 
-    const int row = row0 + threadIdx.x;
-
-    if (row >= nrows) {
-        return;
+    if (warp_id == 0 && lane_id == 0) {
+        dst[row] = argmax;
     }
-
-    dst[row] = argmax_thread;
 }
 
 void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
-
-    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const int64_t num_blocks = nrows;
+    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
+    const dim3 blocks_dim(num_threads, 1, 1);
     const dim3 blocks_num(num_blocks, 1, 1);
 
-    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
 }
index e146c691c6f873769635c509981d16f82a199971..b0dd16066b4ba5a46199305f1a2e8d15752ff30c 100644 (file)
@@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
@@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
     }
     return a;
 }
@@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
         reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
         reinterpret_cast<half&>(a.y) += __high2half(a_other);
     }
     return a;
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
     }
     return a;
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
     }
     return x;
 }
@@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
-   for (int mask = 16; mask > 0; mask >>= 1) {
-       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+   for (int offset = 16; offset > 0; offset >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
 #else
index 45408ce8684e431806b4735abe5ddc0a7d3ac267..1702e4ce2feba476c0637f32d3280df81cf2f4c4 100644 (file)
@@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(
 
     // Exchange max. abs. value between vals_per_scale/4 threads.
 #pragma unroll
-    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
     }
 
     float sum;
@@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(
 
         // Exchange calculate sum across vals_per_sum/4 threads.
 #pragma unroll
-        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
-            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
         }
     }
 
index 719d75c70c75d29e9d3f00e17df8ed9e423e5337..78e7874dee04d496c7c9e7d69f7f4a9ace2691d2 100644 (file)
@@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
     GGML_ASSERT(ggml_is_matrix(a));
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
 
@@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
         struct ggml_context  * ctx,
         struct ggml_tensor   * a,
         enum ggml_sort_order   order) {
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 
     ggml_set_op_params_i32(result, 0, (int32_t) order);