]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
HIP: remove the use of __HIP_PLATFORM_AMD__, explicitly support only AMD targets...
authoruvos <redacted>
Tue, 29 Jul 2025 18:23:04 +0000 (20:23 +0200)
committerGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mma.cuh
src/ggml-cuda/mmq.cuh
src/ggml-cuda/vendors/hip.h

index 19fcc5982241b1640a74a9f15affe94162396675..44daafbf78ad89d8b67e04637eaa5acfe371ba75 100644 (file)
@@ -176,7 +176,7 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 #    define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes)                                                       \
         do {                                                                                                   \
             static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false };                         \
@@ -191,7 +191,7 @@ static const char * cu_get_error_str(CUresult err) {
         do {                                             \
             GGML_UNUSED(nbytes);                         \
         } while (0)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#endif // !(defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -211,9 +211,9 @@ typedef float2 dfloat2;
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
 
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 
 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
@@ -227,17 +227,17 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#if defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
+#endif // defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
 #if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 #define FLASH_ATTN_AVAILABLE
@@ -259,7 +259,7 @@ static bool fast_fp16_hardware_available(const int cc) {
 
 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
     if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
@@ -275,7 +275,7 @@ static bool fp16_mma_available(const int cc) {
     } else {
         return false;
     }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -312,25 +312,25 @@ static bool cp_async_available(const int cc) {
 }
 
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
+#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
     return 64;
 #else
     return 32;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
+#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
 }
 
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
     printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
            file_name, line, function_name, arch);
     GGML_UNUSED(arch_list);
 #else
     printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
            file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
     __trap();
 
     GGML_UNUSED(no_device_code); // suppress unused function warning
@@ -367,7 +367,7 @@ struct ggml_cuda_unroll<1> {
 
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
@@ -375,7 +375,7 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
         x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
 template<int width = WARP_SIZE>
@@ -444,11 +444,11 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
     return __hmax(a, b);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
 
 #else
    NO_DEVICE_CODE;
@@ -476,7 +476,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
 
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
@@ -485,7 +485,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -497,7 +497,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 #endif // CUDART_VERSION < CUDART_HMASK
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3) || defined(RDNA4)
@@ -523,7 +523,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif
     return c;
 
-#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP)
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     return __dp4a(a, b, c);
@@ -533,7 +533,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
 
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 }
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
index 95e704e393c2a60b975c97549936b1a5b9698fec..0cc74f284a15b8f921b31104e7ac195319258f7a 100644 (file)
@@ -592,9 +592,9 @@ static __global__ void flash_attn_stream_k_fixup(
 }
 
 template<int D> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP)
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
index 83cf872f68a7b1434df0319ca6e0f9cfca542ef5..8e847d361b455b886c60b3c2688797d08c835bc8 100644 (file)
@@ -1391,24 +1391,24 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
             CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
             CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     }
 
     launch_fattn<DV, ncols1, ncols2>
index 7661c21efbbdd9840fd9f2f851eba09e1e09c553..678288c13e3e87e92053dfda0b25af9ce525ab6d 100644 (file)
@@ -5,9 +5,9 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
index 11778bb9611d661defac4b59abf6c01ce544a494..bc283d9a7e4d74db95b0f86db63333f7fda90d88 100644 (file)
@@ -5,9 +5,9 @@
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
index c9b083bed014bb717cc126eb0f4ccde125dea3e3..f4393ec5711d925e95a434ecf6b2a9bd026efd67 100644 (file)
@@ -7,7 +7,7 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 #include <mma.h>
 #ifdef GGML_USE_MUSA
 namespace wmma = mtmusa::wmma;
@@ -18,7 +18,7 @@ namespace wmma = nvcuda::wmma;
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
@@ -546,7 +546,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
     if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
@@ -568,7 +568,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
index 819c2be78ec12fabe5cb6bfadecb0a7fc0ae0e9d..51792794673bbe0216c3522f18864ef7663d6d93 100644 (file)
@@ -128,7 +128,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 }
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 static int ggml_cuda_parse_id(char devName[]) {
     // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
     // these values are not stable so this is susceptible to breakage
@@ -175,10 +175,10 @@ static int ggml_cuda_parse_id(char devName[]) {
     archNum += archMinor;
     return archNum;
 }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 
 static ggml_cuda_device_info ggml_cuda_init() {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(GGML_USE_HIP)
     // Workaround for a rocBLAS bug when using multiple graphics cards:
     // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
     {
@@ -251,7 +251,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].nsm        = prop.multiProcessorCount;
         info.devices[id].smpb       = prop.sharedMemPerBlock;
         info.devices[id].warp_size  = prop.warpSize;
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
         info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
@@ -281,7 +281,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                         id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
     }
 
     for (int id = 0; id < info.device_count; ++id) {
index d6817d804d2f3784a73dd83e1d37bd4e8f044b03..a86365c6a061cd7d8f5f5ac9cd10944158730301 100644 (file)
@@ -68,7 +68,7 @@ namespace ggml_cuda_mma {
         static constexpr int I  = I_;
         static constexpr int J  = J_;
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
@@ -132,7 +132,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
     };
 
     template <int I_, int J_>
index 36e84be154e3c6a74c9f8b231ac86f7b871d22bb..d8650dd70b91515c8120b99ccd37fb617d956750 100644 (file)
@@ -104,9 +104,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 128;
 #else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
     return 64;
-#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP)
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
@@ -118,7 +118,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 }
 
@@ -128,7 +128,7 @@ static int get_mmq_y_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_y_device() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(RDNA1)
     return 64;
 #else
@@ -140,7 +140,7 @@ static constexpr __device__ int get_mmq_y_device() {
 #else
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 }
 
 // Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
@@ -250,7 +250,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
 }
 #endif // AMD_MFMA_AVAILABLE
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 static int mmq_get_nwarps_host(const int cc) {
     return amd_mfma_available(cc) ? 8 : 4;
 }
@@ -258,10 +258,10 @@ static int mmq_get_nwarps_host(const int cc) {
 static int mmq_get_nwarps_host(const int /*cc*/) {
     return 8;
 }
-#endif // (GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // (GGML_USE_HIP)
 
 static constexpr __device__ int mmq_get_nwarps_device() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(AMD_MFMA_AVAILABLE)
     return 8;
 #else
@@ -269,7 +269,7 @@ static constexpr __device__ int mmq_get_nwarps_device() {
 #endif // AMD_MFMA_AVAILABLE
 #else
     return 8;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 }
 
 // ------------------------------------------------------------
@@ -3047,7 +3047,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
 
 template <ggml_type type, int mmq_x, bool need_check>
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
     __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
@@ -3057,7 +3057,7 @@ template <ggml_type type, int mmq_x, bool need_check>
 #else
     __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 static __global__ void mul_mat_q(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -3097,7 +3097,7 @@ static __global__ void mul_mat_q(
     __syncthreads();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#if (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         const int wt = blockIdx.z / nchannels_y;
         const int zt = blockIdx.z - wt*nchannels_y;
@@ -3151,7 +3151,7 @@ static __global__ void mul_mat_q(
              tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
     const     int64_t blocks_per_ne00 = ncols_x / qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
index 56e59a058f999f228e6a75a2376bdd9195800a11..8b172e60f4b7e770d62eb5fb5a7fc8006819a3c3 100644 (file)
@@ -5,10 +5,8 @@
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
 #include <hip/hip_bfloat16.h>
-#ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
-#endif // __HIP_PLATFORM_AMD__
 
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
-#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
+#if HIP_VERSION >= 70000000
 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define cublasComputeType_t hipblasDatatype_t
 #define cudaDataType_t hipblasDatatype_t
-#endif
+#endif // HIP_VERSION >= 7000000
+
+#if !defined(__HIP_PLATFORM_AMD__)
+#error "The HIP backend supports only AMD targets"
+#endif // !defined(__HIP_PLATFORM_AMD__)
 
 #define __CUDA_ARCH__ 1300
 
@@ -249,7 +251,7 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
     return c;
 }
 
-#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+#if HIP_VERSION < 50600000
 // __shfl_xor() for half2 was added in ROCm 5.6
 static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
     typedef union half2_b32 {
@@ -261,4 +263,4 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
     tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
     return tmp.val;
 }
-#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+#endif // HIP_VERSION < 50600000