]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
HIP: fix flash_attn_stream_k_fixup warning (llama/11604)
authorJohannes Gäßler <redacted>
Sun, 2 Feb 2025 22:48:29 +0000 (23:48 +0100)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/softmax.cu

index cfd7c0f4475dc2e1829697161de11f9d4ed254a9..d40ee2da4188799bc8d52c3d9a565ed5b7c0eebf 100644 (file)
@@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
 template<int D, int ncols, int KQ_stride> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 }
 
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
index da377200e01bb33cdcf7915aa4271b9ea142f288..aac6e0999880a529e63a8a796e1fbcbffa60e57f 100644 (file)
@@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
-#endif
+#endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
         const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
 }
 #ifdef __clang__
 #pragma clang diagnostic pop
-#endif
+#endif // __clang__
 
 static __global__ void soft_max_back_f32(
         const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {