using namespace ggml_cuda_mma;
-typedef tile<16, 8, half2> tile_A;
-typedef tile< 8, 8, half2> tile_B;
-typedef tile<16, 8, half2> tile_B_16;
-typedef tile<16, 8, float> tile_C_KQ;
-typedef tile<16, 16, float> tile_C_KQ_16;
-typedef tile<16, 4, half2> tile_C_VKQ;
-typedef tile<16, 8, half2> tile_C_VKQ_16;
-
-// Config options for specific head sizes.
+// Config options for the MMA kernel.
// Should not affect results, only speed/register pressure/shared memory use.
-//
-// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
-// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
-// Q_in_reg: whether the Q values should be kept permanently in registers.
-// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
-// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel.
-// nbatch_V2: number of V half2 values in direction of DV to load in parallel.
-// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
-
-template <int DKQ, int DV>
-struct fattn_mma_f16_config;
-
-template <>
-struct fattn_mma_f16_config< 64, 64> {
- static constexpr int nbatch_fa = 64;
- static constexpr int nwarps_max = 4;
- static constexpr bool Q_in_reg = true;
- static constexpr int nstages_target = 2;
-
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
- return 32;
- }
-
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
- return 32;
- }
-
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
- return 32;
- }
-
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
- return 32;
- }
-
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
- return 32;
- }
-
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
- return 32;
- }
+struct fattn_mma_config {
+ int nthreads; // Number of threads per CUDA block.
+ int occupancy; // Targeted occupancy for the MMA kernel.
+ int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+ int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel.
+ int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel.
+ int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
+ int nstages_target; // Targeted number of pipeline stages: 1 == always load data synchronously, 2 == preload data if there is hardware support.
+ bool Q_in_reg; // Whether the Q values should be kept permanently in registers.
+
+ constexpr __host__ __device__ fattn_mma_config(
+ int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
+ nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
+ nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
};
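+
+// Illustrative example (values taken from the Ampere table below): for DKQ == DV == 128 and ncols == 32 the
+// config is fattn_mma_config{128, 2, 64, 64, 64, 64, 2, true}, i.e. 128 threads per block, a targeted
+// occupancy of 2 (presumably resident blocks per SM), softmax rescaling every 64 KV rows, K/V/combine
+// batches of 64 half2 values, a 2-stage cp_async pipeline, and Q kept permanently in registers.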
-template <>
-struct fattn_mma_f16_config< 80, 80> {
- static constexpr int nbatch_fa = 64;
- static constexpr int nwarps_max = 4;
- static constexpr bool Q_in_reg = true;
- static constexpr int nstages_target = 2;
-
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
- return 40;
- }
-
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
- return 40;
- }
-
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
- return 40;
- }
-
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
- return 40;
- }
-
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
- return 40;
- }
-
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
- return 40;
- }
-};
-
-template <>
-struct fattn_mma_f16_config< 96, 96> {
- static constexpr int nbatch_fa = 64;
- static constexpr int nwarps_max = 4;
- static constexpr bool Q_in_reg = true;
- static constexpr int nstages_target = 2;
-
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
- return 48;
- }
-
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
- return 48;
- }
-
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
- return 48;
- }
-
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
- return 48;
- }
-
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
- return 48;
- }
+#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
+ if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+ static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
+ static_assert( (occupancy_) <= 8, "bad occupancy"); \
+ static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
+ static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
+ static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
+ static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
+ static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
+ return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
+ } \
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
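+ // Fallback sentinel for (DKQ, DV, ncols) combinations not listed above; callers are presumably expected
+ // to check support before relying on these values (assumption, not enforced in this function).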
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+}
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
- return 48;
- }
-};
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
-template <>
-struct fattn_mma_f16_config<112, 112> {
- static constexpr int nbatch_fa = 64;
- static constexpr int nwarps_max = 4;
- static constexpr bool Q_in_reg = true;
- static constexpr int nstages_target = 2;
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
- return 56;
- }
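+ // Shapes without a Turing-specific entry fall back to the Ampere table.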
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
- return 56;
- }
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
- return 56;
- }
+ // TODO tune specifically for Volta
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
- return 56;
+static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+ if (ampere_mma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
-
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
- return 56;
+ if (turing_mma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
}
+ GGML_ASSERT(volta_mma_available(cc));
+ return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+}
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
- return 56;
- }
-};
+static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+#if defined(AMPERE_MMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+#elif defined(TURING_MMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(VOLTA_MMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#else
+ GGML_UNUSED_VARS(DKQ, DV, ncols);
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+#endif // defined(AMPERE_MMA_AVAILABLE)
+}
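+
+// Note: the __host__ overload dispatches on the runtime compute capability, while the __device__ overload is
+// resolved at compile time from the architecture macros; both should select the same config as long as the
+// compiled architecture matches the device (assumption).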
-template <>
-struct fattn_mma_f16_config<128, 128> {
- static constexpr int nbatch_fa = 64;
- static constexpr int nwarps_max = 4;
- static constexpr bool Q_in_reg = true;
- static constexpr int nstages_target = 2;
+static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
+}
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
- return 64;
- }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
+}
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
- return 64;
- }
+static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
+}
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
- return 64;
- }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
+}
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
- return 64;
- }
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
+}
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
- return 64;
- }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
+}
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
- return 64;
- }
-};
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
+}
-template <>
-struct fattn_mma_f16_config<256, 256> {
- static constexpr int nbatch_fa = 32;
- static constexpr int nwarps_max = 4;
- static constexpr bool Q_in_reg = true;
- static constexpr int nstages_target = 2;
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
+}
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
- return 128;
- }
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
+}
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
- return 128;
- }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
+}
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
- return 128;
- }
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
+}
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
- return 128;
- }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
+}
- static int get_nbatch_combine_host(const int cc, const int ncols) {
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
- return ncols <= 16 ? 128 : 64;
- }
- return 64;
- }
+static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
+}
- static constexpr __device__ int get_nbatch_combine_device(int ncols) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
- return ncols <= 16 ? 128 : 64;
-#else
- GGML_UNUSED(ncols);
- return 128;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
- }
-};
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
+}
-template <>
-struct fattn_mma_f16_config<576, 512> {
- static constexpr int nbatch_fa = 32;
- static constexpr int nwarps_max = 8;
- static constexpr bool Q_in_reg = false;
- static constexpr int nstages_target = 1;
+static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
+}
- static int get_nbatch_K2_host(const int cc, const int ncols) {
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
- return ncols <= 16 ? 96 : 160;
- }
- return ncols <= 16 ? 288 : 160;
- }
+static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
+}
- static constexpr __device__ int get_nbatch_K2_device(int ncols) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
- return ncols <= 16 ? 96 : 160;
-#else
- return ncols <= 16 ? 288 : 160;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
- }
+// ------------------------------------------------------------------------------------------------------------------
- static int get_nbatch_V2_host(const int cc, const int ncols) {
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
- return ncols <= 16 ? 64 : 128;
- }
- return ncols <= 16 ? 256 : 128;
- }
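+// cp_async pipelining is only used for ncols2 >= 2, presumably because with ncols2 == 1 the last iteration
+// may need the oob_check path, which the static_asserts further down declare incompatible with cp_async.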
+static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
+ return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
+}
- static constexpr __device__ int get_nbatch_V2_device(int ncols) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
- return ncols <= 16 ? 64 : 128;
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
+#ifdef CP_ASYNC_AVAILABLE
+ return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
#else
- return ncols <= 16 ? 256 : 128;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
- }
-
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
- return 128;
- }
-
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
- return 128;
- }
-};
+ GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
+ return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
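+
+// In the kernels below nstages == 0 means all data is loaded synchronously, nstages == 1 means cp_async is
+// used within an iteration, and nstages > 1 means the K tile and mask for the next iteration are preloaded.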
// ------------------------------------------------------------------------------------------------------------------
-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
- const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
-
- if (use_cp_async) {
+ if constexpr (use_cp_async) {
+ static_assert(!oob_check, "OOB check not compatible with cp_async");
constexpr int preload = 64;
constexpr int h2_per_chunk = 16/sizeof(half2);
const int chunks_per_row = D2 / h2_per_chunk;
}
}
};
- ggml_cuda_unroll<5>{}(load);
+ // 1: max 32*16=512 bytes, 256 half
+ // 2: max 16*16=256 bytes, 128 half
+ // 3: max 8*16=128 bytes, 64 half
+ // 4: max 4*16= 64 bytes, 32 half
+ // 5: max 2*16= 32 bytes, 16 half
+ // 6: max 1*16= 16 bytes, 8 half
+ ggml_cuda_unroll<6>{}(load);
} else {
- static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+ // TODO use ggml_cuda_memcpy_1
auto load = [&] __device__ (const int n) {
const int stride_k = WARP_SIZE >> n;
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
- tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+ tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
}
}
};
- ggml_cuda_unroll<3>{}(load);
+ // 1: max 32* 4=128 bytes, 64 half
+ // 2: max 16* 4= 64 bytes, 32 half
+ // 3: max 8* 4= 32 bytes, 16 half
+ // 4: max 4* 4= 16 bytes, 8 half
+ ggml_cuda_unroll<4>{}(load);
}
}
-template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
- const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
- static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
-
- if (use_cp_async) {
+ const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
+ const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
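+ // There are four code paths: bulk copies via cp_async (no OOB handling), a scalar path that zero-pads
+ // out-of-bounds rows when oob_check is set, and two synchronous half2 copy paths for small/large nbatch_fa.
+ // In all of them the j (Q column) index wraps around ne01 via fastmodulo.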
+ if constexpr (use_cp_async) {
+ static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
+ static_assert(!oob_check, "OOB check incompatible with cp_async");
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
#pragma unroll
- for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
- const int j = j0 + threadIdx.y*cols_per_warp +
- (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
+ for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
- if (j0 + stride_j > ncols1 && j >= ncols1) {
+ if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
break;
}
- const int i = 4 * (threadIdx.x % (nbatch_fa/8));
+ const int i = 8 * (threadIdx.x % (nbatch_fa/8));
- cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+ cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
}
- return;
- }
+ } else if constexpr (oob_check) {
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+ const int j_sram = j1 + threadIdx.y;
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+ if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+ break;
+ }
- constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
- constexpr int stride_j = nwarps * cols_per_warp;
#pragma unroll
- for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
- const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
+ for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- if (j0 + stride_j > ncols1 && j >= ncols1) {
- break;
+ tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
+ }
}
+ } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
+ constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+ constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
- const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
+ if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
+ break;
+ }
+
+ const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
- tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
+ }
+ } else {
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+ const int j_sram = j1 + threadIdx.y;
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+ if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+ break;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+ const int i = i0 + 2*threadIdx.x;
+
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
+ }
+ }
}
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
- bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
+ bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
+ typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
- const half2 * const __restrict__ mask_h2,
+ const half * const __restrict__ mask_h,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
const float slope,
const float logit_softcap,
- const int ne01,
+ const uint3 ne01,
const int ne02,
const int stride_K,
const int stride_V,
half2 * const __restrict__ tile_Q,
half2 * const __restrict__ tile_K,
half2 * const __restrict__ tile_V,
- half2 * const __restrict__ tile_mask,
- const tile_B * const __restrict__ Q_B,
- tile_C_VKQ * const __restrict__ VKQ_C,
+ half * const __restrict__ tile_mask,
+ T_B_KQ * const __restrict__ Q_B,
+ T_C_VKQ * const __restrict__ VKQ_C,
float * const __restrict__ KQ_max,
float * const __restrict__ KQ_rowsum,
- const int kb0) {
-#ifdef TURING_MMA_AVAILABLE
- typedef fattn_mma_f16_config<DKQ, DV> c;
-
-#ifdef CP_ASYNC_AVAILABLE
- constexpr int nstages = c::nstages_target;
-#else
- constexpr int nstages = 0;
-#endif // CP_ASYNC_AVAILABLE
-
- constexpr int cols_per_warp = ntiles * tile_B::I;
- constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
- constexpr int ncols = ncols1 * ncols2;
- constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
- constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
+ const int jt,
+ const int kb0,
+ const int k_VKQ_sup) {
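+ // k_VKQ_sup is the exclusive upper bound ("supremum") on the KV rows of this tile that hold valid data;
+ // it is only checked when oob_check == true.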
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int cols_per_warp = T_B_KQ::I;
+ constexpr int cols_per_thread = 2; // Number of KQ columns per thread; Volta only has a single VKQ column.
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
+ constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+ constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
- const int k_VKQ_0 = kb0 * c::nbatch_fa;
- tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
-
- // Use wide variants of tiles if ntiles >= 2.
- tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
- tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
- tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
+ const int k_VKQ_0 = kb0 * nbatch_fa;
+#if defined(TURING_MMA_AVAILABLE)
+ T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#else // Volta
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
if constexpr (nstages > 1) {
+ static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
static_assert(!mla, "multi-stage loading not implemented for MLA");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true;
cp_async_wait_all();
__syncthreads();
- flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
} else {
constexpr bool use_cp_async = nstages == 1;
- if (ncols2 > 1 || mask_h2) {
- flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+ if (ncols2 > 1 || mask_h) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
}
}
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
- if (nstages <= 1) {
+ if constexpr (nstages <= 1) {
constexpr bool use_cp_async = nstages == 1;
- flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
if (use_cp_async) {
cp_async_wait_all();
}
}
// Calculate tile of KQ:
- if constexpr (c::Q_in_reg) {
+ if constexpr (Q_in_reg) {
#pragma unroll
- for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+ for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
#pragma unroll
- for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
- tile_A K_A;
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+ T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- if (ntiles == 1) {
- mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
} else {
-#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
- // Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
- }
+ // Wide version of KQ_C is column-major => swap A and B.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
}
}
}
} else {
- static_assert(ntiles == 2, "ntiles != 2 not implemented");
+ static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
- for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
- load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+ load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
#pragma unroll
- for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+ for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
- tile_A K_A;
+ T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
// Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
}
}
}
- if (nstages <= 1) {
+ if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
}
}
if (use_logit_softcap) {
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+ constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
+ static_assert(nbatch_fa % stride == 0, "bad loop size");
#pragma unroll
- for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
+ for (int i = 0; i < nbatch_fa/stride; ++i) {
#pragma unroll
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
float KQ_rowsum_add[cols_per_thread] = {0.0f};
- if (ntiles == 1) {
- if (ncols2 > 1 || mask_h2) {
+ if constexpr (cols_per_warp == 8) {
+ if (ncols2 > 1 || mask_h) {
#pragma unroll
- for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
- const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+ for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
+ const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
#pragma unroll
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
- const int i = i0 + tile_C_KQ::get_i(l);
- const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ const int i = i0 + T_C_KQ::get_i(l);
+ const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
- KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
- __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
+ KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+ static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
#pragma unroll
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
#pragma unroll
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
- KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
+ KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
+ }
}
}
- // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+ // Values per KQ column are spread across 8 threads:
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
}
}
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+ static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
#pragma unroll
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
#pragma unroll
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
- KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
-
- KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
+ KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+ } else {
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
+ }
}
}
- } else { // ntiles > 1
- if (ncols2 > 1 || mask_h2) {
+ } else { // Volta, or Turing mma with T_B_KQ::I > 8
+ if (ncols2 > 1 || mask_h) {
#pragma unroll
- for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
- const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+ for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
+ const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
-#pragma unroll
- for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
- const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
- const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+ for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
+ const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
+ const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
- const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
- const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
- KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
- KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
- }
+ const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
+ KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
+ KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
-#pragma unroll
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+ static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
- for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
- const int KQ_index = 2*t + (l/2) % 2;
- KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
+ // Turing + Volta:
+ KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
}
}
}
- // Values per KQ column are spread across 4 threads, does not need full warp reduce:
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
+#if defined(TURING_MMA_AVAILABLE)
+ // Values per KQ column are spread across 4 threads:
+ constexpr int offset_first = 2;
+ constexpr int offset_last = 1;
+#else
+ // Values per KQ column are spread across 2 threads:
+ constexpr int offset_first = 2;
+ constexpr int offset_last = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
#pragma unroll
- for (int offset = 2; offset >= 1; offset >>= 1) {
+ for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
}
}
- static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+ static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
#pragma unroll
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
-#pragma unroll
- for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
- const int KQ_index = 2*t + (l/2) % 2;
-
- KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
-
- KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ // Turing + Volta:
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
+ KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+ } else {
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
}
}
}
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
}
- if (ntiles == 1) {
+#if defined(TURING_MMA_AVAILABLE)
+ if constexpr (cols_per_warp == 8) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
- for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+ for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
- for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
- for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
- for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
- VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+ for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+ VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
}
}
}
}
+#else // Volta
+ const half2 KQ_max_scale_h2 = make_half2(
+ KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
}
// Convert KQ C tiles into B tiles for VKQ calculation:
- tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
- tile_B_16 * B_16 = (tile_B_16 *) B;
- static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
- if (ntiles == 1) {
+ T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
+ static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
+ if constexpr (cols_per_warp == 8) {
#pragma unroll
- for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
+ for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
B[k] = get_transposed(get_half2(KQ_C[k]));
}
} else {
- for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
-#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
- B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
- }
+ for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
+ B[k] = get_half2(KQ_C[k]);
}
}
- if (nstages > 1) {
+ if constexpr (nstages > 1) {
// Preload K tile for next iteration:
constexpr bool use_cp_async = true;
cp_async_wait_all();
__syncthreads();
if (!last_iter) {
- if (ncols2 > 1 || mask_h2) {
- flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
- (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
+ if (ncols2 > 1 || mask_h) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
}
- flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
}
}
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+
+ // Calculate VKQ tile; i0 needs to index logical rather than physical elements due to the transposition of V:
#pragma unroll
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
const int i0_diff = i0_stop - i0_start;
- if (nstages <= 1 && i0_start < reusable_cutoff) {
- constexpr bool use_cp_async = nstages == 1;
- flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
- if (use_cp_async) {
- cp_async_wait_all();
+ if constexpr (nstages <= 1) {
+ if (i0_start < reusable_cutoff) {
+ constexpr bool use_cp_async = nstages == 1;
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
+ if (use_cp_async) {
+ cp_async_wait_all();
+ }
+ __syncthreads();
}
- __syncthreads();
}
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
- // Calculate VKQ tile:
+#if defined(TURING_MMA_AVAILABLE)
+ constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
#pragma unroll
- for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
- static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+ static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
#pragma unroll
- for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
- const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+ for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
+ const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
- tile_A A;
+ T_A_VKQ A; // Transposed in SRAM but not in registers; gets transposed on load.
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
- if (ntiles == 1) {
- mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+ if constexpr (T_B_KQ::I == 8) {
+ mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
} else {
-#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
- // Wide version of VKQ_C is column-major => swap A and B.
- mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
- }
+ // Wide version of VKQ_C is column-major => swap A and B.
+ mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
}
}
}
+#else // Volta
+ constexpr int i0_stride = 2*T_C_VKQ::J;
+#pragma unroll
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+ static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
+ static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
+#pragma unroll
+ for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
+ const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;
+
+ T_A_VKQ A; // Transposed in both SRAM and registers; load normally.
+ load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
+ }
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
- if (nstages <= 1) {
+ if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
}
}
#else
- GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup,
+ GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
scale, slope, logit_softcap, ne01, ne02,
stride_K, stride_V, stride_mask,
tile_Q, tile_K, tile_V, tile_mask,
- Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ Q_B, VKQ_C, KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
NO_DEVICE_CODE;
-#endif // TURING_MMA_AVAILABLE
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
+#if defined(TURING_MMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+ using T_A_KQ = tile<16, 8, half2>; // row-major
+ using T_B_KQ = tile<16, 8, half2>; // column-major
+ using T_C_KQ = tile<16, 16, float>; // column-major
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
+template<> struct mma_tile_sizes<8> {
+ using T_A_KQ = tile<16, 8, half2>; // row-major
+ using T_B_KQ = tile< 8, 8, half2>; // column-major
+ using T_C_KQ = tile<16, 8, float>; // row-major
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
+ using T_B_VKQ = tile< 8, 8, half2>; // column-major
+ using T_C_VKQ = tile<16, 4, half2>; // row-major
+};
+#else // Volta
+template<int ncols> struct mma_tile_sizes {
+ using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
+ using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
+ using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+};
+#endif // defined(TURING_MMA_AVAILABLE)
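+
+// The generic Turing/Ampere template uses 16-column B tiles (ncols > 8), the ncols == 8 specialization uses
+// the narrower 8-column B tiles, and on Volta a single set of 32-column tiles is used for all column counts.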
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
- const half2 * const __restrict__ mask_h2,
+ const half * const __restrict__ mask_h,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
const float slope,
const float logit_softcap,
- const int ne01,
+ const uint3 ne01,
const int ne02,
+ const int ne11,
const int stride_Q1,
const int stride_Q2,
const int stride_K,
const int jt,
const int kb0_start,
const int kb0_stop) {
-#ifdef TURING_MMA_AVAILABLE
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
- typedef fattn_mma_f16_config<DKQ, DV> c;
-
-#ifdef CP_ASYNC_AVAILABLE
- constexpr int nstages = c::nstages_target;
-#else
- constexpr int nstages = 0;
-#endif // CP_ASYNC_AVAILABLE
-
- constexpr int ncols = ncols1 * ncols2;
- constexpr int cols_per_warp = ntiles * tile_B::I;
- constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
- constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
- constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
+ constexpr int ncols = ncols1 * ncols2;
+ using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
+ using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
+ using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
+ using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
+ using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
+ using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
+
+ constexpr int cols_per_warp = T_B_KQ::I;
+ constexpr int cols_per_thread = 2; // Number of KQ columns per thread; Volta only has a single VKQ column.
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
+ constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
+ constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+ constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
+
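+ // Bail out if a single warp's B tile spans more columns than the kernel processes in total
+ // (e.g. on Volta, where T_B_KQ::I == 32).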
+ if (cols_per_warp > ncols) {
+ NO_DEVICE_CODE;
+ return;
+ }
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[];
- half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
- half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
- half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
-
- tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
- tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
+ half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
+ half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
+ half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
- tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
- tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+ T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
+#if defined(TURING_MMA_AVAILABLE)
+ T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#else // Volta
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
float KQ_rowsum[cols_per_thread] = {0.0f};
float KQ_max[cols_per_thread];
const int j = jc / ncols2;
const int c = jc % ncols2;
- if (jt*ncols1 + j < ne01) {
+ if (jt*ncols1 + j < int(ne01.z)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
__syncthreads();
- if (c::Q_in_reg) {
+ if (Q_in_reg) {
const int j0 = (threadIdx.y / np) * cols_per_warp;
#pragma unroll
- for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
- if (ntiles == 1) {
- load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
- } else {
-#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
- load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
- tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
- }
- }
+ for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
+ load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
}
}
__syncthreads();
+ int kb0 = kb0_start;
+
// Preload mask and K data for first iteration when using cp_async with multiple stages:
if constexpr (nstages > 1) {
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
constexpr bool use_cp_async = true;
- if (ncols2 > 1 || mask_h2) {
- flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
- (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
+ constexpr bool oob_check = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ if (ncols2 > 1 || mask_h) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
}
- flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
}
- // Iterate over ne11 == previous tokens:
- int kb0 = kb0_start;
for (; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
- }
- { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+ constexpr bool oob_check = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ }
+ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+ if constexpr (ncols2 == 1) {
+ if (ne11 % nbatch_fa == 0) {
+ constexpr bool last_iter = true;
+ constexpr bool oob_check = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ } else {
+ constexpr bool last_iter = true;
+ constexpr bool oob_check = true;
+ const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ }
+ } else {
constexpr bool last_iter = true;
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ constexpr bool oob_check = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
}
// With multi-stage loading there is no __syncthreads at the end of the iter,
// there can be a race condition on shared memory access for combining/writing back results.
- if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
+ if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
__syncthreads();
}
// Finally, sum up partial KQ rowsums.
- // The partial sums are spread across 8/4 threads each, does not need full reduce.
{
- constexpr int offset_first = ntiles == 1 ? 16 : 2;
- constexpr int offset_last = ntiles == 1 ? 4 : 1;
+#if defined(TURING_MMA_AVAILABLE)
+ // The partial sums are spread across 8/4 threads.
+ constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
+ constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
+#else // Volta
+ // The partial sums are spread across 2 threads.
+ constexpr int offset_first = 2;
+ constexpr int offset_last = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
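+        // offset_first/offset_last bound the strides of the warp shuffles that sum up the partial rowsums below.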
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
- static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
- const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+ const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
- if (ntiles == 1) {
+#if defined(TURING_MMA_AVAILABLE)
+ if constexpr (cols_per_warp == 8) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
- for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+ for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
- for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
- for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
- for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
- VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+ for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+ VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
}
}
}
}
+#else // Volta
+ const int col = (threadIdx.x / 2) % 2;
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
}
// Combine VKQ accumulator values if np > 1.
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
- constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
- constexpr int tile_stride = nbatch_combine + 4;
+ constexpr int tile_stride = nbatch_combine + 4;
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
- if constexpr (ntiles == 1) {
- const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
- const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+ if constexpr (cols_per_warp == 8) {
+ const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
+ const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
- if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
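+            // (tile_stride == nbatch_combine + 4, i.e. 4 spare half2 values == 16 bytes per row;
+            //  the float2 below goes into the first 8 of those bytes.)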
((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
}
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
- if (needs_fixup && threadIdx.x < tile_B::I) {
+ if (needs_fixup && threadIdx.x < T_B_KQ::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[jc_cwm] = KQ_cmr;
}
- if (is_fixup && threadIdx.x < tile_B::I) {
+ if (is_fixup && threadIdx.x < T_B_KQ::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[jc_cwm] = KQ_cmr;
}
}
} else {
- static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
- const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
- + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
- + tile_C_VKQ_16::get_i(threadIdx.x % 4);
- const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
-
- if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
- // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+ // jc_cwm = jc combine write meta
+ // KQ_cmr = KQ combine max rowsum
+ // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
+#if defined(TURING_MMA_AVAILABLE)
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
+ const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
+ const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#else // Volta
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
+ const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
+ const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
+#endif // defined(TURING_MMA_AVAILABLE)
+
+ if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
}
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
- if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+ if (needs_fixup && thread_should_write) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[jc_cwm] = KQ_cmr;
}
- if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+ if (is_fixup && thread_should_write) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[jc_cwm] = KQ_cmr;
}
}
}
- static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
if (np > 1 && threadIdx.y % np == 0) {
// Combine the meta data for parallel warps via shared memory.
// Warps with threadIdx.y % np != 0 must NOT return early.
#pragma unroll
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
- if (ntiles == 1) {
- const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+ if constexpr (cols_per_warp == 8) {
+ const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
#pragma unroll
- for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
- const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
+ const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
#pragma unroll
- for (int l = 0; l < tile_B::ne; ++l) {
- const int k = k0 + tile_B::get_j(l);
+ for (int l = 0; l < T_B_KQ::ne; ++l) {
+ const int k = k1 + T_B_KQ::get_j(l);
tile_Q[jc_cwd*tile_stride + k] = B.x[l];
}
}
} else {
+ const int j0 = threadIdx.y*cols_per_warp;
#pragma unroll
- for (int t = 0; t < ntiles/2; ++t) {
- const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
#pragma unroll
- for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
-#pragma unroll
- for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
- const int j = j0 + tile_C_VKQ_16::get_i(l);
- const int k = k0 + tile_C_VKQ_16::get_j(l);
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ const int j = j0 + T_C_VKQ::get_i(l);
+ const int k = k1 + T_C_VKQ::get_j(l);
- tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
- }
+ tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
}
}
}
const int j_dst = jc_dst / ncols2;
const int c_dst = jc_dst % ncols2;
- if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+ if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
continue;
}
}
}
#else
- GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup,
+ GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
scale, slope, logit_softcap, ne01, ne02,
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
-#endif // TURING_MMA_AVAILABLE
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
+__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
- typedef fattn_mma_f16_config<DKQ, DV> c;
-
- static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+ constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
+ constexpr int nwarps = nthreads / WARP_SIZE;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int stride_Q1 = nb01 / sizeof(float2);
const int stride_Q2 = nb02 / sizeof(float2);
const int stride_K = nb11 / sizeof(half2);
- const int stride_mask = nb31 / sizeof(half2);
+ const int stride_mask = nb31 / sizeof(half);
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
- const int iter_k = ne11 / FATTN_KQ_STRIDE;
- const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
-
- constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+ const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
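+    // e.g. ne11 == 1000 and nbatch_fa == 64 -> iter_k == 16 (ceiling division).
+    // ne01 is passed as a packed uint3 (presumably for fast integer division); ne01.z holds the plain ne01 value.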
// kbc == k block continuous, current index in continuous ijk space.
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int head0 = zt * ncols2;
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
- const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
- (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+ const half * mask_h = ncols2 == 1 && !mask ? nullptr :
+ (const half *) (mask + nb33*(sequence % ne33));
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
- const int kb0_start_kernel = kb0_start * kb_niter;
- int kb0_stop_kernel = kb0_stop * kb_niter;
-
if (KV_max) {
- kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+ kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
}
-
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
- (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
} else {
- constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
- (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
}
kbc += iter_k;
const int head0 = zt * ncols2;
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
- const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
- (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+ const half * mask_h = ncols2 == 1 && !mask ? nullptr :
+ (const half *) (mask + nb33*(sequence % ne33));
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
- const int kb0_start_kernel = kb0_start * kb_niter;
- int kb0_stop_kernel = kb0_stop * kb_niter;
-
if (KV_max) {
- kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+ kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
}
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
- (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
}
template <int DKQ, int DV, int ncols1, int ncols2>
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
- typedef fattn_mma_f16_config<DKQ, DV> c;
+ constexpr int ncols = ncols1 * ncols2;
- const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
+ const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc);
+ const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc);
+ const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc);
+ const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc);
+ const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
+ const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
+ const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
- constexpr int ncols = ncols1 * ncols2;
- constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
- constexpr int cols_per_warp = ntiles * tile_B::I;
- constexpr int nwarps_max_x = ncols / cols_per_warp;
- constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
- constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
+ const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
+ const int nwarps = nthreads / WARP_SIZE;
constexpr bool mla = DKQ == 576;
- const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
- const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
- const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
-
- static_assert(DKQ % tile_B::J == 0, "bad DKQ");
- static_assert(DV % tile_A::J == 0, "bad DV");
- static_assert(ncols % cols_per_warp == 0, "bad ncols");
-
- const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
- const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
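+    // With a single pipeline stage K and V are loaded one after the other and can share a single buffer (max),
+    // with two stages the next tile is preloaded asynchronously so the K and V buffers must coexist (sum).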
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
- const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2);
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
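+    // If Q is kept in registers its shared memory is only needed until Q has been loaded,
+    // so the buffer can be reused for K/V and the mask; otherwise Q stays resident and the sizes add up.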
- const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
+ const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
}
launch_fattn<DV, ncols1, ncols2>
- (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
}
namespace ggml_cuda_mma {
+    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel;
+    // effectively the warp is split into subgroups of threads that each perform a single mma instruction.
+    // In those cases the data can be distributed in different ways across the warp.
+ enum data_layout {
+ // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+ // For the A/C matrices this means I major == row major, J major == column major.
+ // For the B matrix this means I major == column major, J major == row major.
+ // MIRRORED == Each data value is held exactly once per thread subgroup.
+ DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
+ DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+ };
+ // Implemented mma combinations are:
+ // - (I_MAJOR, I_MAJOR) -> I_MAJOR
+ // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+ // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
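+    // E.g. the Volta mma overloads below multiply an I-major tile<32, 4, half2> A tile with a mirrored
+    // tile<8, 4, half2> B tile and accumulate into an I-major tile<32, 8, float> or tile<32, 4, half2> C tile.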
+
+ template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
+ struct tile {};
+
template <int I_, int J_, typename T>
- struct tile {
- static constexpr int I = I_;
- static constexpr int J = J_;
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
- return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
#else
- return (l & 2) | (threadIdx.x & ~2);
+ return (l & 2) + (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
} else {
NO_DEVICE_CODE;
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 32 && J == 8) {
- return (threadIdx.x & 2) | (l & (4 + 1));
+ return (threadIdx.x & 2) + (l & (4 + 1));
} else {
NO_DEVICE_CODE;
return -1;
} else if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
- return ((l / 2) * 8) | (threadIdx.x / 4);
+ return ((l / 2) * 8) + (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 16) {
- return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+ return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
} else if constexpr (I == 32 && J == 8) {
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
} else {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
- return (l * 4) | (threadIdx.x % 4);
+ return (l * 4) + (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 8) {
- return ((threadIdx.x % 4) * 2) | (l % 2);
+ return ((threadIdx.x % 4) * 2) + (l % 2);
} else if constexpr (I == 16 && J == 16) {
- return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+ return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
} else if constexpr (I == 32 && J == 8) {
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
} else {
};
template <int I_, int J_>
- struct tile<I_, J_, half2> {
- static constexpr int I = I_;
- static constexpr int J = J_;
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
- static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+ static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
- if (I == 8 && J == 8) return true;
- if (I == 32 && J == 8) return true;
+ if (I == 32 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
- if constexpr (I == 8 && J == 8) {
- return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
- } else if constexpr (I == 32 && J == 8) {
+ if constexpr (I == 32 && J == 4) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
- return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
#else
return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
}
static __device__ __forceinline__ int get_j(const int l) {
- if constexpr ((I == 8 || I == 32) && J == 8) {
+ if constexpr (I == 32 && J == 4) {
return l;
} else {
NO_DEVICE_CODE;
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
- return (l * 8) | (threadIdx.x / 4);
+ return (l * 8) + (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 8) {
- return ((l % 2) * 8) | (threadIdx.x / 4);
+ return ((l % 2) * 8) + (threadIdx.x / 4);
} else if constexpr (I == 32 && J == 8) {
- return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
+ return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
} else {
NO_DEVICE_CODE;
return -1;
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
- return (l * 4) | (threadIdx.x % 4);
+ return (l * 4) + (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
- return ((l / 2) * 4) | (threadIdx.x % 4);
+ return ((l / 2) * 4) + (threadIdx.x % 4);
} else if constexpr (I == 32 && J == 8) {
- return ((l & 2) * 2) | (threadIdx.x % 4);
+ return ((l & 2) * 2) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
};
template <int I_, int J_>
- struct tile<I_, J_, nv_bfloat162> {
- static constexpr int I = I_;
- static constexpr int J = J_;
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+ static constexpr int ne = I * J / WARP_SIZE;
-#if defined(AMD_WMMA_AVAILABLE)
- static constexpr int ne = I * J / 32;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+#if defined(AMD_WMMA_AVAILABLE)
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}
}
#else
- static constexpr int ne = I * J / WARP_SIZE;
- nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
-
static constexpr __device__ bool supported() {
if (I == 8 && J == 8) return true;
if (I == 16 && J == 4) return true;
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
- return (l * 8) | (threadIdx.x / 4);
+ return (l * 8) + (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 8) {
- return ((l % 2) * 8) | (threadIdx.x / 4);
+ return ((l % 2) * 8) + (threadIdx.x / 4);
} else {
NO_DEVICE_CODE;
return -1;
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
- return (l * 4) | (threadIdx.x % 4);
+ return (l * 4) + (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
- return ((l / 2) * 4) | (threadIdx.x % 4);
+ return ((l / 2) * 4) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
#endif // defined(AMD_WMMA_AVAILABLE)
};
+ template <int I_, int J_>
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+ static constexpr int ne = I * J / (WARP_SIZE/4);
+
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
+ if constexpr (I == 8 && J == 4) {
+ return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+ };
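+    // In the 8x4 I-major mirrored layout above each thread holds one full row (get_j(l) == l),
+    // e.g. thread 21 holds row (21/16)*4 + 21%4 == 5; the tile is replicated four times across the warp.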
+
+ template <int I_, int J_>
+ struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+ static constexpr int ne = I * J / (WARP_SIZE/4);
+
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return ((l / 2) * 4) + (threadIdx.x % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return ((threadIdx.x / 16) * 2) + (l % 2);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+ };
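+    // In the 8x4 J-major mirrored layout above thread 17 with l == 3 holds the value at
+    // i == (3/2)*4 + 17%4 == 5, j == (17/16)*2 + 3%2 == 3; the tile is again replicated four times across the warp.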
+
+#if defined(TURING_MMA_AVAILABLE)
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
return ret;
}
+#else // Volta
+ template <int I, int J>
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+ tile<I, J/2, half2> ret;
+#pragma unroll
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
+ ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+ ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
+
+            // On Volta, FP16 and FP32 tiles have different memory layouts;
+            // for the conversion, threads with an offset of 2 need to exchange half of their values:
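+            // (Within each pair of threads t and t^2 the lower thread's second freshly converted half2
+            //  is swapped with the upper thread's first one.)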
+ ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
+ 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
+ }
+ return ret;
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
- template <int I, int J, typename T>
- static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+ template <int I, int J, typename T, data_layout dl>
+ static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
: "l"(xs));
#else
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
- GGML_UNUSED_VARS(t, xs0, stride);
- NO_DEVICE_CODE;
-#else
- load_generic(t, xs0, stride);
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-#endif // TURING_MMA_AVAILABLE
- }
-
- template <typename T>
- static __device__ __forceinline__ void load_ldmatrix(
- tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
// TODO: more generic handling
load_generic(t, xs0, stride);
#endif // 1
#else
- tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
- load_ldmatrix(t16[0], xs0 + 0*stride, stride);
- load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+ load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+ }
+
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+#pragma unroll
+ for (int l0 = 0; l0 < t.ne; l0 += 2) {
+ ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
+ }
+ }
+
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+#else
+ GGML_UNUSED_VARS(t, xs0, stride);
+ NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
template <typename T1, typename T2, int J, int K>
static __device__ __forceinline__ void mma(
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
- tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
- tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+ tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
+ const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
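+        // A 32-row tile is just two 16-row tiles stacked in the I direction,
+        // so the mma is done as two 16-row mmas that share the same B tile.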
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
}
static __device__ __forceinline__ void mma(
- tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+ tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
- asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
- "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
- : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
- : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
- asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
- "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
- : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
- : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
- tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
- const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
- mma(D16[0], A16[0], B);
- mma(D16[1], A16[1], B);
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
static __device__ __forceinline__ void mma(