]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: add BF16 support (llama/11093)
authorJohannes Gäßler <redacted>
Mon, 6 Jan 2025 01:33:52 +0000 (02:33 +0100)
committerGeorgi Gerganov <redacted>
Tue, 14 Jan 2025 08:38:01 +0000 (10:38 +0200)
* CUDA: add BF16 support

ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmv.cu
ggml/src/ggml-cuda/vendors/cuda.h
ggml/src/ggml-cuda/vendors/hip.h
ggml/src/ggml-cuda/vendors/musa.h

index 3896f956d7338f4cc25b7bbc1fe26acbf48d902a..5b0dfacefc9dab78c7fe099b6b171ca1c57714f9 100644 (file)
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
index c180adc84b861a9ca003a7af1047d8b6444e1c96..0b06be729864e6de01b8d97a5da858f5017a4b3f 100644 (file)
@@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_mul_mat_vec   = src0->type == GGML_TYPE_F16
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+                    case GGML_TYPE_BF16:
 #ifdef GGML_USE_MUSA
                         if (a->type == GGML_TYPE_Q3_K) {
                             return false;
index a4b4f6bc10d989ce18bfa551e67d49dbec4ba153..ac45f2d17f1e1fe94ee215c675b6519b12c86a4a 100644 (file)
@@ -1,9 +1,9 @@
 #include "common.cuh"
 #include "mmv.cuh"
 
-template <typename type_acc, int block_size>
+template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
-        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
     const int64_t row     = blockIdx.x;
     const int64_t channel = blockIdx.z;
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
     y   +=  channel               *stride_channel_y;
     dst +=  channel               *stride_channel_dst;
 
-    const half2  * x2 = (const half2  *) x;
     const float2 * y2 = (const float2 *) y;
 
     extern __shared__ char data_mmv[];
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(
 
     float sumf;
 
-    if (std::is_same<type_acc, float>::value) {
-        sumf = 0.0f;
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 * x2 = (const half2 *) x;
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmpx = __half22float2(x2[col2]);
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x * tmpy.x;
-            sumf += tmpx.y * tmpy.y;
-        }
-    } else {
+        if (std::is_same<type_acc, float>::value) {
+            sumf = 0.0f;
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                const float2 tmpy = y2[col2];
+                sumf += tmpx.x * tmpy.x;
+                sumf += tmpx.y * tmpy.y;
+            }
+        } else {
 #ifdef FP16_AVAILABLE
-        half2 sumh2 = make_half2(0.0f, 0.0f);
+            half2 sumh2 = make_half2(0.0f, 0.0f);
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmp = y2[col2];
-            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
-        }
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmp = y2[col2];
+                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            }
 
-        sumf = __low2float(sumh2) + __high2float(sumh2);
+            sumf = __low2float(sumh2) + __high2float(sumh2);
 #else
-        NO_DEVICE_CODE;
+            NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+        const int * x2 = (const int *) x;
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int    tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        }
+    } else {
+        static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
     sumf = warp_reduce_sum(sumf);
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
     dst[row] = sumf;
 }
 
-template <typename type_acc>
+template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         cudaStream_t stream) {
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case   64: {
-            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case   96: {
-            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  128: {
-            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  160: {
-            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  192: {
-            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  224: {
-            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  256: {
-            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         default: {
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
     }
 }
 
+template<typename T>
 static void mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
     }
 }
 
 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
-    const half  * src0_d = (const half  *) src0->data;
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
     const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
 
-    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
 }
 
 void ggml_cuda_op_mul_mat_vec(
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t channel_stride_y   = 0;
     const int64_t channel_stride_dst = 0;
 
-    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
 
     GGML_UNUSED(ctx);
     GGML_UNUSED(src1);
index db9f6a165d07c241a9b555a42f2893a9dfb6af71..1746b073203e3e134faadc618474b43efbe36745 100644 (file)
@@ -3,6 +3,7 @@
 #include <cuda_runtime.h>
 #include <cuda.h>
 #include <cublas_v2.h>
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
 #if CUDART_VERSION < 11020
index 3205534d66f10a2176de1b91465252a0553b5af4..c905b15d76cc3c118564f66ed6295df2d82a5d2f 100644 (file)
@@ -3,6 +3,7 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
     #define __has_builtin(x) 0
 #endif
 
+typedef hip_bfloat16 nv_bfloat16;
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
index 1604b8229d57fe3deb765dfbef42811cc739099c..6cc1b69ee3390c619c574c3c6e1a871051ba8e94 100644 (file)
@@ -3,6 +3,7 @@
 #include <musa_runtime.h>
 #include <musa.h>
 #include <mublas.h>
+#include <musa_bf16.h>
 #include <musa_fp16.h>
 #define CUBLAS_COMPUTE_16F CUDA_R_16F
 #define CUBLAS_COMPUTE_32F CUDA_R_32F
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
 #define cudaStreamEndCapture musaStreamEndCapture
+
+typedef mt_bfloat16 nv_bfloat16;