]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HIP: Cleanup hipification header (#15285)
authoruvos <redacted>
Thu, 14 Aug 2025 14:23:56 +0000 (16:23 +0200)
committerGitHub <redacted>
Thu, 14 Aug 2025 14:23:56 +0000 (16:23 +0200)
add expicit conversion operator to support older versions of rocm
Switch over to hip_bf16 from legacy hip_bfloat16
Simplify RDNA3 define
Reduce swap over of new hipblas api to rocm 6.5 as this version is used for rocm 7.0 previews

---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/convert.cuh
ggml/src/ggml-cuda/cpy-utils.cuh
ggml/src/ggml-cuda/getrows.cu
ggml/src/ggml-cuda/mmvf.cu
ggml/src/ggml-cuda/set-rows.cu
ggml/src/ggml-cuda/vendors/hip.h

index e3beddbc1b23b6c1831ba9b9232572b265f92d27..8f0efdcc1260b621ed981d89e424689db19dc52c 100644 (file)
@@ -31,8 +31,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     dequantize_kernel(vx, ib, iqs, v);
 
     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = float(v.x);
-    y[iy0 + y_offset] = float(v.y);
+    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }
 
 template <bool need_check>
@@ -630,7 +630,7 @@ static __global__ void convert_unary(
 
     const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
     const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = float(x[ix]);
+    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
 }
 
 template <typename src_t, typename dst_t>
index f04214be175ba5e1662333ff46e120dcadd38032..c62e8a1b1040a771435cfd71c2fd29e5945b7b36 100644 (file)
@@ -29,3 +29,16 @@ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
 to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
 to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
+
+template<typename dst_t, typename src_t>
+ __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
+        return x;
+    } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
+        return __float2bfloat16(float(x));
+    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
+        return __bfloat162float(x);
+    } else {
+        return float(x);
+    }
+}
index 410c12b7ba56b1f44ca08751c9cf5ba0417aaeb3..e621cb9811ab62a12d32a848be1de66afdb77725 100644 (file)
@@ -1,15 +1,7 @@
 #pragma once
 
 #include "ggml-common.h"
-
-template<typename src_t, typename dst_t>
-static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
-    if constexpr (std::is_same_v<src_t, dst_t>) {
-        *dst = *src;
-    } else {
-        *dst = float(*src);
-    }
-}
+#include "convert.cuh"
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -221,5 +213,5 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
 
 template<typename src_t, typename dst_t>
 static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
-    convert_flt((const src_t *)cxi, (dst_t *)cdsti);
+    *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
 }
index f77b2629a19b078b606b49d3a17aa3e135034237..68d3254fbe4723ee77e5b17f153c588980bc0442 100644 (file)
@@ -1,5 +1,6 @@
 #include "getrows.cuh"
 #include "dequantize.cuh"
+#include "convert.cuh"
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst_row[iybs + iqs + 0]        = float(v.x);
-    dst_row[iybs + iqs + y_offset] = float(v.y);
+    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }
 
 template<typename src0_t, typename dst_t>
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = float(src0_row[i00]);
+    dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 }
 
 template<typename grad_t, typename dst_t>
index 1ad4bc75ba6145766f37e595425844d51d57ae7e..16100b680456a5c4a3be9e9485576469c02e495c 100644 (file)
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "common.cuh"
+#include "convert.cuh"
 #include "mmvf.cuh"
 
 template <typename T, typename type_acc, int ncols_dst, int block_size>
@@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
             }
         }
     } else {
index 07983436459d42eac501e5eb9987105af94233d1..b4115a43c2a3296103af86baa41a36eca8f0b8f5 100644 (file)
@@ -3,11 +3,6 @@
 
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
-template<typename src_t, typename dst_t>
-__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    convert_flt(src_f, dst_f);
-}
-
 // Generic quantized set_rows kernel template
 template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 static __global__ void k_set_rows_quant(
@@ -117,9 +112,7 @@ static __global__ void k_set_rows(
     const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
     dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;
 
-    const src_t* src_elem = src0_row + i00;
-    dst_t* dst_elem = dst_row_ptr + i00;
-    set_rows_1(src_elem, dst_elem);
+    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 
     GGML_UNUSED(ne10);
     GGML_UNUSED(ne13);
index ec1b59caafc9aa98541fc8638b5b203193ec9ce0..6e9c67aca096e68c136c2103b5d178fff603ad60 100644 (file)
@@ -4,7 +4,7 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
-#include <hip/hip_bfloat16.h>
+#include <hip/hip_bf16.h>
 
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
-#if HIP_VERSION >= 70000000
+#if HIP_VERSION >= 60500000
 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define cublasComputeType_t hipblasDatatype_t
 #define cudaDataType_t hipblasDatatype_t
-#endif // HIP_VERSION >= 7000000
+#endif // HIP_VERSION >= 6050000
 
 #if !defined(__HIP_PLATFORM_AMD__)
 #error "The HIP backend supports only AMD targets"
 #define RDNA4
 #endif
 
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-    defined(__gfx1150__) || defined(__gfx1151__)
+#if defined(__GFX11__)
 #define RDNA3
 #endif
 
     #define __has_builtin(x) 0
 #endif
 
-typedef hip_bfloat16 nv_bfloat16;
-typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));