]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: non-contiguous (RMS) norm support (llama/11659)
authorJohannes Gäßler <redacted>
Tue, 4 Feb 2025 21:21:42 +0000 (22:21 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
* CUDA: non-contiguous (RMS) norm support

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/norm.cu
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-vulkan/ggml-vulkan.cpp

index bda10aec1180ab7a3f60798982a0764bb1e41b62..70a5980998c4ab4b0e87254d34197a30b5529e0a 100644 (file)
@@ -38,6 +38,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml.h"
 
 #include <algorithm>
 #include <array>
@@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
@@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
index d991ec972813ffb583e89944898a113bc1c3548d..f127616eddade711fd7c200c58ac4e479d9ee89c 100644 (file)
@@ -1,12 +1,20 @@
 #include "norm.cuh"
+#include <cstdint>
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float2 mean_var = make_float2(0.0f, 0.0f);
 
@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
 
-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
 
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 9605914ffa430c6f786bfffd94fbca78e27d7887..0a264be371e51bae69d495825d1406b4ddedea38 100644 (file)
@@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
             return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
index 9ca3959abf128c1cc4171241ef070c2fb4f0cc2a..48ac489a6554f867f6faa8eba1ced4ef3310c093 100644 (file)
@@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+            return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
         case GGML_OP_MUL: