]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: replace GGML_CUDA_F16 with CUDA arch checks (llama/15433)
authorJohannes Gäßler <redacted>
Wed, 20 Aug 2025 14:58:49 +0000 (16:58 +0200)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:53:59 +0000 (12:53 +0300)
CMakeLists.txt
src/ggml-cuda/CMakeLists.txt
src/ggml-cuda/common.cuh
src/ggml-cuda/convert.cu
src/ggml-cuda/cpy.cu
src/ggml-cuda/dequantize.cuh
src/ggml-cuda/getrows.cu
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/vecdotq.cuh
src/ggml-musa/CMakeLists.txt

index 90e274ccdbccfbac64c8dbc0555ef2beddee9dfd..2ead001e2c61082d7332bca3cc21b02a87606944 100644 (file)
@@ -158,7 +158,6 @@ option(GGML_CUDA                            "ggml: use CUDA"
 option(GGML_MUSA                            "ggml: use MUSA"                                  OFF)
 option(GGML_CUDA_FORCE_MMQ                  "ggml: use mmq kernels instead of cuBLAS"         OFF)
 option(GGML_CUDA_FORCE_CUBLAS               "ggml: always use cuBLAS instead of mmq kernels"  OFF)
-option(GGML_CUDA_F16                        "ggml: use 16 bit floats for some calculations"   OFF)
 set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF)
index bce07ac362830c164ecb9aebcfd92617944a6845..ea824965aae2d17bd1faeb6b422b3e54dd7a6141 100644 (file)
@@ -24,12 +24,6 @@ if (CUDAToolkit_FOUND)
         #     for best performance and to also build real architectures for the most commonly used GPUs.
         if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
             set(CMAKE_CUDA_ARCHITECTURES "native")
-        elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
-                set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
-            else()
-                set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
-            endif()
         else()
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
                 set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
@@ -91,10 +85,6 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_FA)
     endif()
 
-    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-        add_compile_definitions(GGML_CUDA_F16)
-    endif()
-
     if (GGML_CUDA_NO_PEER_COPY)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()
index 76ace816ff6fcdca389d0f06e7e83fb69604ed8b..767ad83f60eb50942ede80e12ec0fc3947dff381 100644 (file)
@@ -206,14 +206,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define GGML_CUDA_ASSUME(x)
 #endif // CUDART_VERSION >= 11010
 
-#ifdef GGML_CUDA_F16
-typedef half dfloat; // dequantize float
-typedef half2 dfloat2;
-#else
-typedef float dfloat; // dequantize float
-typedef float2 dfloat2;
-#endif // GGML_CUDA_F16
-
 #if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
@@ -559,7 +551,7 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
     const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
index 8f0efdcc1260b621ed981d89e424689db19dc52c..7a8b6fdf5f493f73bb1573486775fc685bdd1cd8 100644 (file)
@@ -27,7 +27,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    dfloat2 v;
+    float2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
index f9bb025643ca2e1600844d2cb4661952ddf20aae..0380784ab49186c6783e95c616bde9f820356ff5 100644 (file)
@@ -42,7 +42,7 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
 
 #pragma unroll
     for (int j = 0; j < QK8_0; j += 2) {
-        dfloat2 dq;
+        float2 dq;
         dequantize_q8_0(cxi, 0, j, dq);
         *(cdstf + j) = dq.x;
         *(cdstf + j + 1) = dq.y;
@@ -55,7 +55,7 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 
 #pragma unroll
     for (int j = 0; j < qk/2; j++) {
-        dfloat2 dq;
+        float2 dq;
         dequant(cxi, 0, j, dq);
         *(cdstf + j) = dq.x;
         *(cdstf + j + qk/2) = dq.y;
index bd3c2d9db94639f87b92a7cc52185a633e8e8cb5..e060fb29fdc03d22aab88c1aca8af44062f41f45 100644 (file)
@@ -1,48 +1,37 @@
 #include "common.cuh"
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const float d = x[ib].d;
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-#ifdef GGML_CUDA_F16
-    v = __hsub2(v, {8.0f, 8.0f});
-    v = __hmul2(v, {d, d});
-#else
     v.x = (v.x - 8.0f) * d;
     v.y = (v.y - 8.0f) * d;
-#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = __low2half(x[ib].dm);
-    const dfloat m = __high2half(x[ib].dm);
+    const float2 dm = __half22float2(x[ib].dm);
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-#ifdef GGML_CUDA_F16
-    v = __hmul2(v, {d, d});
-    v = __hadd2(v, {m, m});
-#else
-    v.x = (v.x * d) + m;
-    v.y = (v.y * d) + m;
-#endif // GGML_CUDA_F16
+    v.x = (v.x * dm.x) + dm.y;
+    v.y = (v.y * dm.x) + dm.y;
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const float d = x[ib].d;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -53,20 +42,14 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-#ifdef GGML_CUDA_F16
-    v = __hsub2(v, {16.0f, 16.0f});
-    v = __hmul2(v, {d, d});
-#else
     v.x = (v.x - 16.0f) * d;
     v.y = (v.y - 16.0f) * d;
-#endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const dfloat d = __low2half(x[ib].dm);
-    const dfloat m = __high2half(x[ib].dm);
+    const float2 dm = __half22float2(x[ib].dm);
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -77,27 +60,18 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-#ifdef GGML_CUDA_F16
-    v = __hmul2(v, {d, d});
-    v = __hadd2(v, {m, m});
-#else
-    v.x = (v.x * d) + m;
-    v.y = (v.y * d) + m;
-#endif // GGML_CUDA_F16
+    v.x = (v.x * dm.x) + dm.y;
+    v.y = (v.y * dm.x) + dm.y;
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const float d = x[ib].d;
 
     v.x = x[ib].qs[iqs + 0];
     v.y = x[ib].qs[iqs + 1];
 
-#ifdef GGML_CUDA_F16
-    v = __hmul2(v, {d, d});
-#else
     v.x *= d;
     v.y *= d;
-#endif // GGML_CUDA_F16
 }
index 68d3254fbe4723ee77e5b17f153c588980bc0442..3ec0e957ab5542bf866226e447d4d9eeaf275f95 100644 (file)
@@ -32,7 +32,7 @@ static __global__ void k_get_rows(
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
-    dfloat2 v;
+    float2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
     dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
index d6402a8daaccf50c162cf5c3b5b652d7baa3e0c5..8b706752bc7a0a4f2fc3bd268da51f53a5c40f34 100644 (file)
@@ -3672,10 +3672,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
         features.push_back({ "NO_PEER_COPY", "1" });
     #endif
 
-    #ifdef GGML_CUDA_F16
-        features.push_back({ "F16", "1" });
-    #endif
-
     #ifdef GGML_CUDA_USE_GRAPHS
         features.push_back({ "USE_GRAPHS", "1" });
     #endif
index d8f9aa5ba62242041835e1a4425e27210f227903..d60292b83b1067dd202ba3baa15b1485a9f5ac4f 100644 (file)
@@ -87,7 +87,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
         sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
     }
 
-#ifdef GGML_CUDA_F16
+#ifdef FAST_FP16_AVAILABLE
     const float2 tmp = __half22float2(__hmul2(dm4, ds8));
     const float d4d8 = tmp.x;
     const float m4s8 = tmp.y;
@@ -96,7 +96,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     const float2 ds8f = __half22float2(ds8);
     const float d4d8 = dm4f.x * ds8f.x;
     const float m4s8 = dm4f.y * ds8f.y;
-#endif // GGML_CUDA_F16
+#endif // FAST_FP16_AVAILABLE
 
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
     return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
@@ -158,7 +158,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
         sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
     }
 
-#ifdef GGML_CUDA_F16
+#ifdef FAST_FP16_AVAILABLE
     const float2 tmp = __half22float2(__hmul2(dm5, ds8));
     const float d5d8 = tmp.x;
     const float m5s8 = tmp.y;
@@ -167,7 +167,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     const float2 ds8f = __half22float2(ds8);
     const float d5d8 = dm5f.x * ds8f.x;
     const float m5s8 = dm5f.y * ds8f.y;
-#endif // GGML_CUDA_F16
+#endif // FAST_FP16_AVAILABLE
 
     // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
     return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
@@ -201,7 +201,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
         sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
     }
 
-#ifdef GGML_CUDA_F16
+#ifdef FAST_FP16_AVAILABLE
     const float2 tmp = __half22float2(__hmul2(dm8, ds8));
     const float d8d8 = tmp.x;
     const float m8s8 = tmp.y;
@@ -210,7 +210,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     const float2 ds8f = __half22float2(ds8);
     const float d8d8 = dm8f.x * ds8f.x;
     const float m8s8 = dm8f.y * ds8f.y;
-#endif // GGML_CUDA_F16
+#endif // FAST_FP16_AVAILABLE
 
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
index 02904526ade04575d99b31bbb140883b1c67b2ce..cdb3818c786c793a9ee4d02c0e4098ca520184fb 100644 (file)
@@ -96,10 +96,6 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_FA)
     endif()
 
-    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-        add_compile_definitions(GGML_CUDA_F16)
-    endif()
-
     if (GGML_CUDA_NO_PEER_COPY)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()