]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : replace asserts in wrong architecture checks with __trap (#4556)
authorslaren <redacted>
Thu, 21 Dec 2023 17:02:30 +0000 (18:02 +0100)
committerGitHub <redacted>
Thu, 21 Dec 2023 17:02:30 +0000 (18:02 +0100)
* cuda : replace asserts in wrong architecture checks with __trap

* make bad_arch noreturn, remove returns

ggml-cuda.cu

index 28d3787843800d9cbe2f0f090d44cd5c3a517a46..e7c9dee4560635efd4f7b89cb1da74e41c494ba6 100644 (file)
@@ -512,6 +512,14 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+[[noreturn]]
+static __device__ void bad_arch() {
+    printf("ERROR: ggml-cuda was compiled without support for the current GPU architecture.\n");
+    __trap();
+
+    (void) bad_arch; // suppress unused function warning
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -1972,8 +1980,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
     // second part effectively subtracts 8 from each quant value
     return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2010,8 +2017,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
     return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2046,8 +2052,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
     // second part effectively subtracts 16 from each quant value
     return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2092,8 +2097,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2114,8 +2118,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
 
     return d8_0*d8_1 * sumi;
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2145,8 +2148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2181,8 +2183,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2219,8 +2220,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
 
     return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2260,8 +2260,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
 
     return d3 * sumf;
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2286,8 +2285,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 
     return d3*d8 * sumi;
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2320,8 +2318,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2354,8 +2351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2395,8 +2391,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2429,8 +2424,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2460,8 +2454,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
 
     return d*sumf;
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2492,8 +2485,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -3359,8 +3351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     return dall * sumf_d - dmin * sumf_m;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
 #endif
@@ -3543,8 +3534,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     return d * sumf_d;
 
 #else
-    assert(false);
-    return 0.0f; // only to satisfy the compiler
+    bad_arch();
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
 #endif
@@ -3954,7 +3944,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q4_0_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4023,7 +4013,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q4_1_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4090,7 +4080,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q5_0_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4157,7 +4147,7 @@ mul_mat_q5_1(
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q5_1_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4224,7 +4214,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q8_0_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4291,7 +4281,7 @@ mul_mat_q2_K(
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q2_K_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4360,7 +4350,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q3_K_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4429,7 +4419,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q4_K_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4496,7 +4486,7 @@ mul_mat_q5_K(
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q5_K_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4565,7 +4555,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q6_K_q8_1_mul_mat;
-    assert(false);
+    bad_arch();
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }