]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: mul_mat_q RDNA2 tunings (#2910)
authorJohannes Gäßler <redacted>
Wed, 13 Sep 2023 09:20:24 +0000 (11:20 +0200)
committerGitHub <redacted>
Wed, 13 Sep 2023 09:20:24 +0000 (11:20 +0200)
* CUDA: mul_mat_q RDNA2 tunings

* Update ggml-cuda.cu

Co-authored-by: Henri Vasserman <redacted>
---------

Co-authored-by: Henri Vasserman <redacted>
CMakeLists.txt
Makefile
ggml-cuda.cu

index 4f7b05fc2bfc54ed5a4cc945ef52dbbdfc27478b..12adaf328a355218e5650b1355b70514af1dfb0e 100644 (file)
@@ -388,7 +388,6 @@ if (LLAMA_HIPBLAS)
         target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
         target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
         target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
-        target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000)
         set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
         target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
 
index a774dc50f372d6f2738684cf72e0aad6615514a5..5b65dd1f2333b807acec7d9ceb311747e16ec414 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -408,7 +408,6 @@ ifdef LLAMA_HIPBLAS
        HIPFLAGS    += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
        HIPFLAGS    += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
        HIPFLAGS    += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
-       HIPFLAGS    += -DCC_TURING=1000000000
 ifdef LLAMA_CUDA_FORCE_DMMV
        HIPFLAGS        += -DGGML_CUDA_FORCE_DMMV
 endif # LLAMA_CUDA_FORCE_DMMV
index 1d8bc2699c84f071516ad74a0bc37eb1f74a400b..fe7332b2a3580cb6a6b8c5f1ff04e3c8cc8e0979 100644 (file)
@@ -13,7 +13,7 @@
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
-#endif
+#endif // __HIP_PLATFORM_AMD__
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
-#endif
+#endif // defined(GGML_USE_HIPBLAS)
 
 #include "ggml-cuda.h"
 #include "ggml.h"
 
-#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#ifndef CC_TURING
-#define CC_TURING   700
-#endif
+#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_TURING     700
+#define CC_OFFSET_AMD 1000000
+#define CC_RDNA2      CC_OFFSET_AMD + 1030
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+    defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
 #ifndef __has_builtin
     #define __has_builtin(x) 0
 #endif
@@ -132,7 +142,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
-#endif
+#endif // defined(GGML_USE_HIPBLAS)
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -3472,6 +3482,12 @@ static __device__ __forceinline__ void mul_mat_q(
     }
 }
 
+#define  MMQ_X_Q4_0_RDNA2  64
+#define  MMQ_Y_Q4_0_RDNA2  128
+#define NWARPS_Q4_0_RDNA2  8
+#define  MMQ_X_Q4_0_RDNA1  64
+#define  MMQ_Y_Q4_0_RDNA1  64
+#define NWARPS_Q4_0_RDNA1  8
 #define  MMQ_X_Q4_0_AMPERE 64
 #define  MMQ_Y_Q4_0_AMPERE 128
 #define NWARPS_Q4_0_AMPERE 4
@@ -3479,11 +3495,32 @@ static __device__ __forceinline__ void mul_mat_q(
 #define  MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_0(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q4_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q4_0_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q4_0_RDNA2;
+    const int nwarps = NWARPS_Q4_0_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q4_0_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q4_0_RDNA1;
+    const int nwarps = NWARPS_Q4_0_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
     const int nwarps = NWARPS_Q4_0_AMPERE;
@@ -3506,6 +3543,12 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q4_1_RDNA2  64
+#define  MMQ_Y_Q4_1_RDNA2  128
+#define NWARPS_Q4_1_RDNA2  8
+#define  MMQ_X_Q4_1_RDNA1  64
+#define  MMQ_Y_Q4_1_RDNA1  64
+#define NWARPS_Q4_1_RDNA1  8
 #define  MMQ_X_Q4_1_AMPERE 64
 #define  MMQ_Y_Q4_1_AMPERE 128
 #define NWARPS_Q4_1_AMPERE 4
@@ -3514,14 +3557,33 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #define NWARPS_Q4_1_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_TURING
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
 #endif // __CUDA_ARCH__ < CC_TURING
     mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q4_1_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q4_1_RDNA2;
+    const int nwarps = NWARPS_Q4_1_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q4_1_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q4_1_RDNA1;
+    const int nwarps = NWARPS_Q4_1_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
     const int nwarps = NWARPS_Q4_1_AMPERE;
@@ -3544,6 +3606,12 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q5_0_RDNA2  64
+#define  MMQ_Y_Q5_0_RDNA2  128
+#define NWARPS_Q5_0_RDNA2  8
+#define  MMQ_X_Q5_0_RDNA1  64
+#define  MMQ_Y_Q5_0_RDNA1  64
+#define NWARPS_Q5_0_RDNA1  8
 #define  MMQ_X_Q5_0_AMPERE 128
 #define  MMQ_Y_Q5_0_AMPERE 64
 #define NWARPS_Q5_0_AMPERE 4
@@ -3551,11 +3619,32 @@ template <bool need_check> static __global__ void
 #define  MMQ_Y_Q5_0_PASCAL 64
 #define NWARPS_Q5_0_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q5_0(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q5_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q5_0_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q5_0_RDNA2;
+    const int nwarps = NWARPS_Q5_0_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q5_0_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q5_0_RDNA1;
+    const int nwarps = NWARPS_Q5_0_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
     const int nwarps = NWARPS_Q5_0_AMPERE;
@@ -3578,6 +3667,12 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q5_1_RDNA2  64
+#define  MMQ_Y_Q5_1_RDNA2  128
+#define NWARPS_Q5_1_RDNA2  8
+#define  MMQ_X_Q5_1_RDNA1  64
+#define  MMQ_Y_Q5_1_RDNA1  64
+#define NWARPS_Q5_1_RDNA1  8
 #define  MMQ_X_Q5_1_AMPERE 128
 #define  MMQ_Y_Q5_1_AMPERE 64
 #define NWARPS_Q5_1_AMPERE 4
@@ -3585,11 +3680,32 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
 #define  MMQ_Y_Q5_1_PASCAL 64
 #define NWARPS_Q5_1_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q5_1(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+mul_mat_q5_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q5_1_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q5_1_RDNA2;
+    const int nwarps = NWARPS_Q5_1_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q5_1_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q5_1_RDNA1;
+    const int nwarps = NWARPS_Q5_1_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
     const int nwarps = NWARPS_Q5_1_AMPERE;
@@ -3612,6 +3728,12 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q8_0_RDNA2  64
+#define  MMQ_Y_Q8_0_RDNA2  128
+#define NWARPS_Q8_0_RDNA2  8
+#define  MMQ_X_Q8_0_RDNA1  64
+#define  MMQ_Y_Q8_0_RDNA1  64
+#define NWARPS_Q8_0_RDNA1  8
 #define  MMQ_X_Q8_0_AMPERE 128
 #define  MMQ_Y_Q8_0_AMPERE 64
 #define NWARPS_Q8_0_AMPERE 4
@@ -3619,11 +3741,32 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
 #define  MMQ_Y_Q8_0_PASCAL 64
 #define NWARPS_Q8_0_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q8_0(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q8_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q8_0_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q8_0_RDNA2;
+    const int nwarps = NWARPS_Q8_0_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q8_0_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q8_0_RDNA1;
+    const int nwarps = NWARPS_Q8_0_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
     const int nwarps = NWARPS_Q8_0_AMPERE;
@@ -3646,6 +3789,12 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q2_K_RDNA2  64
+#define  MMQ_Y_Q2_K_RDNA2  128
+#define NWARPS_Q2_K_RDNA2  8
+#define  MMQ_X_Q2_K_RDNA1  128
+#define  MMQ_Y_Q2_K_RDNA1  32
+#define NWARPS_Q2_K_RDNA1  8
 #define  MMQ_X_Q2_K_AMPERE 64
 #define  MMQ_Y_Q2_K_AMPERE 128
 #define NWARPS_Q2_K_AMPERE 4
@@ -3653,11 +3802,32 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
 #define  MMQ_Y_Q2_K_PASCAL 64
 #define NWARPS_Q2_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q2_K(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+mul_mat_q2_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q2_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q2_K_RDNA2;
+    const int nwarps = NWARPS_Q2_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q2_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q2_K_RDNA1;
+    const int nwarps = NWARPS_Q2_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
     const int nwarps = NWARPS_Q2_K_AMPERE;
@@ -3680,6 +3850,12 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q3_K_RDNA2  128
+#define  MMQ_Y_Q3_K_RDNA2  64
+#define NWARPS_Q3_K_RDNA2  8
+#define  MMQ_X_Q3_K_RDNA1  32
+#define  MMQ_Y_Q3_K_RDNA1  128
+#define NWARPS_Q3_K_RDNA1  8
 #define  MMQ_X_Q3_K_AMPERE 128
 #define  MMQ_Y_Q3_K_AMPERE 128
 #define NWARPS_Q3_K_AMPERE 4
@@ -3688,14 +3864,33 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #define NWARPS_Q3_K_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_TURING
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
 #endif // __CUDA_ARCH__ < CC_TURING
     mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q3_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q3_K_RDNA2;
+    const int nwarps = NWARPS_Q3_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q3_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q3_K_RDNA1;
+    const int nwarps = NWARPS_Q3_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
     const int nwarps = NWARPS_Q3_K_AMPERE;
@@ -3718,6 +3913,12 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q4_K_RDNA2  64
+#define  MMQ_Y_Q4_K_RDNA2  128
+#define NWARPS_Q4_K_RDNA2  8
+#define  MMQ_X_Q4_K_RDNA1  32
+#define  MMQ_Y_Q4_K_RDNA1  64
+#define NWARPS_Q4_K_RDNA1  8
 #define  MMQ_X_Q4_K_AMPERE 64
 #define  MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
@@ -3726,14 +3927,33 @@ template <bool need_check> static __global__ void
 #define NWARPS_Q4_K_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_TURING
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
 #endif // __CUDA_ARCH__ < CC_TURING
     mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q4_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q4_K_RDNA2;
+    const int nwarps = NWARPS_Q4_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q4_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q4_K_RDNA1;
+    const int nwarps = NWARPS_Q4_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
     const int nwarps = NWARPS_Q4_K_AMPERE;
@@ -3756,6 +3976,12 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q5_K_RDNA2  64
+#define  MMQ_Y_Q5_K_RDNA2  128
+#define NWARPS_Q5_K_RDNA2  8
+#define  MMQ_X_Q5_K_RDNA1  32
+#define  MMQ_Y_Q5_K_RDNA1  64
+#define NWARPS_Q5_K_RDNA1  8
 #define  MMQ_X_Q5_K_AMPERE 64
 #define  MMQ_Y_Q5_K_AMPERE 128
 #define NWARPS_Q5_K_AMPERE 4
@@ -3763,11 +3989,32 @@ template <bool need_check> static __global__ void
 #define  MMQ_Y_Q5_K_PASCAL 64
 #define NWARPS_Q5_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q5_K(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+mul_mat_q5_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q5_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q5_K_RDNA2;
+    const int nwarps = NWARPS_Q5_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q5_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q5_K_RDNA1;
+    const int nwarps = NWARPS_Q5_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
     const int nwarps = NWARPS_Q5_K_AMPERE;
@@ -3790,6 +4037,12 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #endif // __CUDA_ARCH__ >= CC_TURING
 }
 
+#define  MMQ_X_Q6_K_RDNA2  64
+#define  MMQ_Y_Q6_K_RDNA2  128
+#define NWARPS_Q6_K_RDNA2  8
+#define  MMQ_X_Q6_K_RDNA1  32
+#define  MMQ_Y_Q6_K_RDNA1  64
+#define NWARPS_Q6_K_RDNA1  8
 #define  MMQ_X_Q6_K_AMPERE 64
 #define  MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
@@ -3798,14 +4051,33 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #define NWARPS_Q6_K_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_TURING
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
 #endif // __CUDA_ARCH__ < CC_TURING
     mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q6_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q6_K_RDNA2;
+    const int nwarps = NWARPS_Q6_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q6_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q6_K_RDNA1;
+    const int nwarps = NWARPS_Q6_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_TURING
     const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
     const int nwarps = NWARPS_Q6_K_AMPERE;
@@ -4588,7 +4860,15 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q4_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
+        nwarps = NWARPS_Q4_0_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q4_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
+        nwarps = NWARPS_Q4_0_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q4_0_AMPERE;
         mmq_y  =  MMQ_Y_Q4_0_AMPERE;
         nwarps = NWARPS_Q4_0_AMPERE;
@@ -4625,7 +4905,15 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q4_1_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
+        nwarps = NWARPS_Q4_1_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q4_1_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
+        nwarps = NWARPS_Q4_1_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q4_1_AMPERE;
         mmq_y  =  MMQ_Y_Q4_1_AMPERE;
         nwarps = NWARPS_Q4_1_AMPERE;
@@ -4662,7 +4950,15 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q5_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
+        nwarps = NWARPS_Q5_0_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q5_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
+        nwarps = NWARPS_Q5_0_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q5_0_AMPERE;
         mmq_y  =  MMQ_Y_Q5_0_AMPERE;
         nwarps = NWARPS_Q5_0_AMPERE;
@@ -4699,7 +4995,15 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q5_1_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
+        nwarps = NWARPS_Q5_1_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q5_1_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
+        nwarps = NWARPS_Q5_1_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q5_1_AMPERE;
         mmq_y  =  MMQ_Y_Q5_1_AMPERE;
         nwarps = NWARPS_Q5_1_AMPERE;
@@ -4736,7 +5040,15 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q8_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
+        nwarps = NWARPS_Q8_0_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q8_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
+        nwarps = NWARPS_Q8_0_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q8_0_AMPERE;
         mmq_y  =  MMQ_Y_Q8_0_AMPERE;
         nwarps = NWARPS_Q8_0_AMPERE;
@@ -4773,7 +5085,15 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q2_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
+        nwarps = NWARPS_Q2_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q2_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
+        nwarps = NWARPS_Q2_K_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q2_K_AMPERE;
         mmq_y  =  MMQ_Y_Q2_K_AMPERE;
         nwarps = NWARPS_Q2_K_AMPERE;
@@ -4812,7 +5132,15 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q3_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
+        nwarps = NWARPS_Q3_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q3_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
+        nwarps = NWARPS_Q3_K_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q3_K_AMPERE;
         mmq_y  =  MMQ_Y_Q3_K_AMPERE;
         nwarps = NWARPS_Q3_K_AMPERE;
@@ -4850,7 +5178,15 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q4_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
+        nwarps = NWARPS_Q4_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q4_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
+        nwarps = NWARPS_Q4_K_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q4_K_AMPERE;
         mmq_y  =  MMQ_Y_Q4_K_AMPERE;
         nwarps = NWARPS_Q4_K_AMPERE;
@@ -4887,7 +5223,15 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q5_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_K_RDNA2;
+        nwarps = NWARPS_Q5_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q5_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_K_RDNA1;
+        nwarps = NWARPS_Q5_K_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q5_K_AMPERE;
         mmq_y  =  MMQ_Y_Q5_K_AMPERE;
         nwarps = NWARPS_Q5_K_AMPERE;
@@ -4924,7 +5268,15 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q6_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
+        nwarps = NWARPS_Q6_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q6_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
+        nwarps = NWARPS_Q6_K_RDNA1;
+    } else if (compute_capability >= CC_TURING) {
         mmq_x  =  MMQ_X_Q6_K_AMPERE;
         mmq_y  =  MMQ_Y_Q6_K_AMPERE;
         nwarps = NWARPS_Q6_K_AMPERE;
@@ -5165,8 +5517,11 @@ void ggml_init_cublas() {
 
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
-
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+#else
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         }
         for (int64_t id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -5451,14 +5806,41 @@ inline void ggml_cuda_op_mul_mat_q(
 }
 
 static int64_t get_row_rounding(ggml_type type) {
-    int max_compute_capability = INT_MIN;
-    for (int id = 0; id < g_device_count; ++id) {
-        if (max_compute_capability < g_compute_capabilities[id]
-                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
-            max_compute_capability = g_compute_capabilities[id];
+    int64_t min_compute_capability = INT_MAX;
+    int64_t max_compute_capability = INT_MIN;
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+            if (min_compute_capability > g_compute_capabilities[id]) {
+                min_compute_capability = g_compute_capabilities[id];
+            }
+            if (max_compute_capability < g_compute_capabilities[id]) {
+                max_compute_capability = g_compute_capabilities[id];
+            }
         }
     }
 
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    switch(type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
+        case GGML_TYPE_F16:
+            return 1;
+        case GGML_TYPE_Q2_K:
+            return max_compute_capability >= CC_RDNA2 ? 128 : 32;
+        case GGML_TYPE_Q3_K:
+            return min_compute_capability < CC_RDNA2 ? 128 : 64;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
+        default:
+            GGML_ASSERT(false);
+    }
+#else
     switch(type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -5479,6 +5861,7 @@ static int64_t get_row_rounding(ggml_type type) {
         default:
             GGML_ASSERT(false);
     }
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 inline void ggml_cuda_op_mul_mat_vec_q(