]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: Volta tensor core support for MMF (#16843)
authorJohannes Gäßler <redacted>
Fri, 31 Oct 2025 14:57:19 +0000 (15:57 +0100)
committerGitHub <redacted>
Fri, 31 Oct 2025 14:57:19 +0000 (15:57 +0100)
* CUDA: Volta tensor core support for MMF

* more generic checks for hardware support

* Update ggml/src/ggml-cuda/mmf.cuh

Co-authored-by: Aman Gupta <redacted>
---------

Co-authored-by: Aman Gupta <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmf.cu
ggml/src/ggml-cuda/mmf.cuh

index 6a472be7fbb664e72a9300ddf966de4488eb98c8..ca876459d404da06bd96b8b91a77769790b65eca 100644 (file)
@@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
+// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#define VOLTA_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }
 
-// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool volta_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
+}
+
 static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
index c1f24243fe3883745c6512520ab54fbc9d6cdc55..a7a28fd1ae66061cd7081af992a395678cbbf5a0 100644 (file)
 
 #include "common.cuh"
 
+// On Volta each warp is doing 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM
 
 #if CUDART_VERSION >= 11080
 
@@ -73,6 +77,15 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 64 && J ==  2) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 32 && J ==  4) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 32) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                 return threadIdx.x % 16;
@@ -85,7 +98,8 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
@@ -101,22 +115,67 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return threadIdx.x % 32;
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 32 && J ==  8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+#else
+                return (l & 2) | (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) | (l & (4 + 1));
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  4) return true;
+            if (I ==  8 && J ==  8) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J ==  8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && (J == 4 || J == 8)) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+                return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
@@ -124,13 +183,16 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return 4 * l + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x % 4) + l % 2;
+                return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 #endif // defined(GGML_USE_HIP)
@@ -140,32 +202,83 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, half2> {
         static constexpr int I  = I_;
         static constexpr int J  = J_;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  8) return true;
+            if (I == 32 && J ==  8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr ((I == 8 || I == 32) && J == 8) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  4) return true;
+            if (I ==  8 && J ==  8) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J ==  8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };
 
     template <int I_, int J_>
@@ -175,27 +288,36 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  8) return true;
+            if (I == 16 && J ==  4) return true;
+            if (I == 16 && J ==  8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
     };
@@ -263,8 +385,12 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
-        GGML_UNUSED(t);
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
 
@@ -277,11 +403,35 @@ namespace ggml_cuda_mma {
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #else
         load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+        load_generic(t, xs0, stride);
+#endif // 1
+#else
+        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
+        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -546,4 +696,43 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
+#else
+        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+    }
 }
index 9e2aaf52d6ccefc2ad0122eb6ab097dc191bec3d..2b0a61395b4588907d4499d71072a50d5e2fcaae 100644 (file)
@@ -148,7 +148,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc);
         case GGML_TYPE_BF16:
             return ampere_mma_available(cc);
         default:
index 49d5295be0ea0515d78aeb011129aa8da6648066..f7e46e2f63b2fdf815582c1d043077adb966481b 100644 (file)
@@ -28,9 +28,19 @@ static __global__ void mul_mat_f(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
+    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
+    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+
+    if (!I_16_supported && !I_32_supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
+
+    typedef tile<I_preferred, 8, T>     tile_A;
+    typedef tile<8,           8, T>     tile_B;
+    typedef tile<I_preferred, 8, float> tile_C;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -232,7 +242,6 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-
 //This kernel is for larger batch sizes of mul_mat_id
 template <typename T, int rows_per_block, int cols_per_block, int nwarps>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
@@ -245,9 +254,19 @@ static __global__ void mul_mat_f_ids(
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
+    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
+    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+
+    if (!I_16_supported && !I_32_supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+
+    typedef tile<I_preferred, 8, T>     tile_A;
+    typedef tile<8,           8, T>     tile_B;
+    typedef tile<I_preferred, 8, float> tile_C;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -533,7 +552,8 @@ void mul_mat_f_cuda(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream, const mmf_ids_data * ids_data) {
-    typedef tile<16, 8, T>     tile_A;
+    typedef tile<16, 8, T>     tile_A_16;
+    typedef tile<32, 8, T>     tile_A_32;
     typedef tile< 8, 8, T>     tile_B;
 
     GGML_ASSERT(ncols_x      % 2 == 0);
@@ -544,7 +564,8 @@ void mul_mat_f_cuda(
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
     const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
 
-    const int device = ggml_cuda_get_device();
+    const int device    = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[device].cc;
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
 
     int64_t nwarps_best     = 1;
@@ -559,7 +580,7 @@ void mul_mat_f_cuda(
     }
 
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
     const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;