#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
-#define GGML_CUDA_CC_PASCAL 600
-#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA 700
-#define GGML_CUDA_CC_TURING 750
-#define GGML_CUDA_CC_AMPERE 800
-#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
+#define GGML_CUDA_CC_PASCAL 600
+#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA 700
+#define GGML_CUDA_CC_TURING 750
+#define GGML_CUDA_CC_AMPERE 800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// GCN/CDNA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
+static bool cp_async_available(const int cc) {
+ return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
--- /dev/null
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bytes.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 byte copy is exposed because the 4 and 8 byte copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+ static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+ if (preload == 256) {
+ asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ } else if (preload == 128) {
+ asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ } else if (preload == 64) {
+ asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ } else
+#endif // CUDART_VERSION >= 11040
+ {
+ asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ }
+#else
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src);
+ NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps, a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+ asm volatile("cp.async.wait_all;");
+#else
+ NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
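+
+// A minimal usage sketch (illustrative only; the function name is hypothetical and dst_smem is
+// assumed to point to 16 byte aligned shared memory). Each thread copies 4 half2 = 16 bytes,
+// waits for its own copy, and a __syncthreads makes the data visible to the other warps:
+//
+//   static __device__ void example_async_fill(const half2 * __restrict__ src, half2 * __restrict__ dst_smem) {
+//       const unsigned int dst_32 = __cvta_generic_to_shared(dst_smem);
+//       const int k = threadIdx.x * 4;                        // 4 half2 = 16 bytes per thread
+//       cp_async_cg_16<0>(dst_32 + k*sizeof(half2), src + k); // no L2 prefetch hint
+//       cp_async_wait_all();                                  // wait for this thread's own copy
+//       __syncthreads();                                      // make the data visible to the other warps
+//   }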
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
- const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int nsm = ggml_cuda_info().devices[id].nsm;
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
- const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
- const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
- const bool short_context = K->ne[1] < 4096;
+ const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
+ const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
const int nblocks_stream_k = 2*nsm;
- blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+ const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+
+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;
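// Worked example for the heuristic above (illustrative numbers): with nsm = 108 and
// ntiles_total = 300, nblocks_stream_k = 216 and tiles_nwaves = ceil(300/216) = 2, so
// tiles_efficiency_percent = 100*300/(216*2) = 69 (integer division); since 69 < 75,
// stream-k is chosen even on GPUs older than Ada Lovelace.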
CUDA_CHECK(cudaGetLastError());
if constexpr (parallel_blocks == 0) {
- if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
#include "common.cuh"
+#include "cp-async.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
-static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
- const float2 * const __restrict__ Q_f2,
- const half2 * const __restrict__ K_h2,
- const half2 * const __restrict__ V_h2,
- const half * const __restrict__ maskh,
- float2 * const __restrict__ dstk,
- float2 * const __restrict__ dstk_fixup,
- const float scale,
- const float slope,
- const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int nb31,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3,
- const int jt,
- const int kb0_start,
- const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+using namespace ggml_cuda_mma;
- typedef mma_A_I16K8<half2> mma_A;
- typedef mma_B_J8K8<half2> mma_B;
- typedef mma_C_I16J8<float> mma_C_KQ;
- typedef mma_C_I16J8<half2> mma_C_VKQ;
-
- static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
- constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
-
- static_assert(D % nwarps == 0, "bad D");
- static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+typedef tile<16, 8, half2> tile_A;
+typedef tile< 8, 8, half2> tile_B;
+typedef tile<16, 8, float> tile_C_KQ;
+typedef tile<16, 4, half2> tile_C_VKQ;
+template<int D, int nwarps, int KQ_stride>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
- extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
- const int stride_Q = nb01 / sizeof(float2);
- const int stride_KV = nb11 / sizeof(half2);
- const int stride_mask = nb31 / sizeof(half);
+ // If cp.async is available, load the start of each row asynchronously, up to the highest power of 2 that fits in D:
+#ifdef CP_ASYNC_AVAILABLE
+ static_assert(D >= 64 && D < 512, "bad D");
+ constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
- mma_B Q_B[D/(2*mma_B::K)];
- mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
+ const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
- float2 KQ_rowsum = {0.0f, 0.0f};
- float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
- float2 KQ_max_scale = {0.0f, 0.0f};
+ constexpr int preload = 64;
+ constexpr int h2_per_chunk = 16/sizeof(half2);
+ constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
+ constexpr int stride_i = WARP_SIZE / chunks_per_row;
+#pragma unroll
+ for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
+ const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
- // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
- // The loading is done with decreasing granularity for D for better memory bandwidth.
- const half2 scale_h2 = make_half2(scale, scale);
+ cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
+ }
+#else
+ constexpr int k0_sync_start = 0;
+#endif // CP_ASYNC_AVAILABLE
+ static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
+
+ // If D is not a power of 2, the rest is loaded synchronously.
+ // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+ static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
- const int stride_j = WARP_SIZE / stride_k;
+ const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
+ const int stride_i = WARP_SIZE / stride_k;
- if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
- break;
+ if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
+ continue;
}
#pragma unroll
- for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
- const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+ for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
- if (jt*ncols + j < ne01) {
-#pragma unroll
- for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
- const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
- tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
- }
- } else {
#pragma unroll
- for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
- tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
- }
+ tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
}
}
}
+}
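+
+// Worked example for the load split above (illustrative, D = 80, i.e. D/2 = 40 half2 per row):
+// k0_sync_start = 32, so half2 elements [0, 32) of each row are loaded with cp.async in 16 byte
+// chunks, while elements [32, 40) fall through to the synchronous loop and are handled by the
+// stride_k == WARP_SIZE/4 pass (the coarser passes hit k0_start == k0_stop and are skipped).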
- __syncthreads();
-
- {
- const int j0 = (threadIdx.y / np) * mma_B::J;
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+ const float2 * const __restrict__ Q_f2,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half * const __restrict__ maskh,
+ float2 * const __restrict__ dstk,
+ float2 * const __restrict__ dstk_fixup,
+ const float scale,
+ const float slope,
+ const float logit_softcap,
+ const int ne01,
+ const int ne02,
+ const int stride_Q,
+ const int stride_KV,
+ const int stride_mask,
+ const int jt,
+ half2 * const __restrict__ tile_K,
+ half2 * const __restrict__ tile_V,
+ const tile_B * const __restrict__ Q_B,
+ tile_C_VKQ * const __restrict__ VKQ_C,
+ float2 & KQ_max,
+ float2 & KQ_rowsum,
+ const int kb0) {
+#ifdef NEW_MMA_AVAILABLE
+ constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
-#pragma unroll
- for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
- Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
- }
- }
+ const int k_VKQ_0 = kb0*KQ_stride;
+ tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
+#ifdef CP_ASYNC_AVAILABLE
+ cp_async_wait_all();
__syncthreads();
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+#else
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
+ __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
- // Iterate over ne11 == previous tokens:
- for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
- const int k_VKQ_0 = kb0*KQ_stride;
- mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
-
- // Load K data into tile with decreasing granularity for D for better memory bandwidth:
- static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
-#pragma unroll
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
- const int stride_i = WARP_SIZE / stride_k;
-
-#pragma unroll
- for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
- const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-
-#pragma unroll
- for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
- const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
- tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
- }
- }
- }
-
- __syncthreads();
-
- // Calculate tile of KQ:
+ // Calculate tile of KQ:
#pragma unroll
- for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
+ for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
- mma_A K_A;
- K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
- KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
- }
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
+ tile_A K_A;
+ load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+ mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
}
+ }
- __syncthreads();
+#ifndef CP_ASYNC_AVAILABLE
+ __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
- if (use_logit_softcap) {
- static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+ if (use_logit_softcap) {
+ static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
- for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
+ for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
#pragma unroll
- for (int l = 0; l < mma_C_KQ::ne; ++l) {
- KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
- }
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
+ }
- if (maskh) {
- static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
- static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
+ if (maskh) {
+ static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
+ static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
#pragma unroll
- for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
- const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
+ for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
+ const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
#pragma unroll
- for (int l = 0; l < mma_C_KQ::ne; ++l) {
- const int i = i0 + mma_C_KQ::get_i(l);
- const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ const int i = i0 + tile_C_KQ::get_i(l);
+ const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
- KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
- }
+ KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
+ }
- // Calculate softmax for each KQ column using the current max. value.
- // The divisor is stored in KQ_rowsum and will be applied at the end.
- float2 KQ_max_new = KQ_max;
- static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
+ float2 KQ_max_new = KQ_max;
+ static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
- for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+ for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
- for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
- KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
- KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
- }
+ for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
+ KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
+ KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
+ }
- // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+ // Values per KQ column are spread across 8 threads, so a full warp reduce is not needed:
#pragma unroll
- for (int offset = 16; offset > 2; offset >>= 1) {
- KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
- KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
- }
-
- {
- const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
- KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
- if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
- KQ_max_scale.x = 0.0f;
- }
- if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
- KQ_max_scale.y = 0.0f;
- }
- KQ_max = KQ_max_new;
- }
+ for (int offset = 16; offset > 2; offset >>= 1) {
+ KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
+ KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
+ }
- float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
- static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+ float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
+ static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
- for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+ for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
- for (int l = 0; l < mma_C_KQ::ne; ++l) {
- const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
- const float diff = KQ_C[k].x[l] - KQ_max_l;
- KQ_C[k].x[l] = expf(diff);
- if (diff <= SOFTMAX_FTZ_THRESHOLD) {
- KQ_C[k].x[l] = 0.0f;
- }
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
+ const float diff = KQ_C[k].x[l] - KQ_max_l;
+ KQ_C[k].x[l] = expf(diff);
- if (l % 2 == 0) {
- KQ_rowsum_add.x += KQ_C[k].x[l];
- } else {
- KQ_rowsum_add.y += KQ_C[k].x[l];
- }
+ if (l % 2 == 0) {
+ KQ_rowsum_add.x += KQ_C[k].x[l];
+ } else {
+ KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
+ }
+
+ {
+ const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
+ const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
+ KQ_max = KQ_max_new;
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
- for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
+ for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
#pragma unroll
- for (int l = 0; l < mma_C_VKQ::ne; ++l) {
+ for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
+ }
+
+ // Convert KQ C tiles into B tiles for VKQ calculation:
+ tile_B B[KQ_stride/(np*2*tile_B::J)];
+ static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
+#pragma unroll
+ for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
+ B[k] = get_transposed(get_half2(KQ_C[k]));
+ }
- // Convert KQ C tiles into B tiles for VKQ calculation:
- mma_B B[KQ_stride/(np*2*mma_B::K)];
- static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
+#ifdef CP_ASYNC_AVAILABLE
+ cp_async_wait_all();
+ __syncthreads();
+ if (!last_iter) {
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
+ }
+#else
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+ __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+ // Calculate VKQ tile:
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
+ static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
#pragma unroll
- for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
- B[k] = KQ_C[k].to_mma_B();
+ for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
+ const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+ tile_A A;
+ load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+ mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+ }
+ }
+
+#ifndef CP_ASYNC_AVAILABLE
+ __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+#else
+ NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
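+
+// Rough shape of the cp.async pipeline implemented above: on entry the K tile for this iteration
+// is already in flight; the iter waits for it, kicks off the asynchronous load of the V tile,
+// computes KQ from tile_K, applies the optional softcap and mask and the online softmax, waits
+// for V, kicks off the load of the next K tile (unless this is the last iteration), and finally
+// computes VKQ from tile_V. Without cp.async, tile_K and tile_V alias the same buffer and every
+// load is synchronous.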
+
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+ const float2 * const __restrict__ Q_f2,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half * const __restrict__ maskh,
+ float2 * const __restrict__ dstk,
+ float2 * const __restrict__ dstk_fixup,
+ const float scale,
+ const float slope,
+ const float logit_softcap,
+ const int ne01,
+ const int ne02,
+ const int stride_Q,
+ const int stride_KV,
+ const int stride_mask,
+ const int jt,
+ const int kb0_start,
+ const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
+ constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
+
+ static_assert(D % nwarps == 0, "bad D");
+ static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+ // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
+ extern __shared__ half2 tile_K[];
+#ifdef CP_ASYNC_AVAILABLE
+ half2 * tile_V = tile_K + KQ_stride*D2_padded;
+#else
+ half2 * tile_V = tile_K;
+#endif // CP_ASYNC_AVAILABLE
+
+ tile_B Q_B[D/(2*tile_B::J)];
+ tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
+
+ float2 KQ_rowsum = {0.0f, 0.0f};
+ float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
+
+ // Temporarily load Q data into tile_K; it will be loaded into registers afterwards.
+ // The loading is done with decreasing granularity for D for better memory bandwidth.
+ const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
+ const int stride_j = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
+ if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
+ break;
}
- // Load V data into tile with decreasing granularity for D for better memory bandwidth:
- static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
- for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
- const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
- const int i0_stop = D/2 - (D/2) % (1*stride_i);
- const int stride_k = WARP_SIZE / stride_i;
+ for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
+ const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+ if (jt*ncols + j < ne01) {
#pragma unroll
- for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
- const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
+ tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+ }
+ } else {
#pragma unroll
- for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
- const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
- tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
+ tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
}
}
}
+ }
- __syncthreads();
+ __syncthreads();
- // Calculate VKQ tile:
-#pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
- static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
-#pragma unroll
- for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
- const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
+ {
+ const int j0 = (threadIdx.y / np) * tile_B::I;
- mma_A A;
- A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
- VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
- }
+#pragma unroll
+ for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+ load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
}
+ }
+
+ __syncthreads();
+ // Preload K data for first iteration when using cp_async:
+#ifdef CP_ASYNC_AVAILABLE
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
+#endif // CP_ASYNC_AVAILABLE
+
+ // Iterate over ne11 == previous tokens:
+ for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+ constexpr bool last_iter = false;
+ flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ }
+ { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+ constexpr bool last_iter = true;
+ flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+ }
+
+ // With cp_async there is no __syncthreads at the end of the iter,
+ // so there can be a race condition on shared memory access when combining/writing back results.
+#ifdef CP_ASYNC_AVAILABLE
+ if (nwarps*tile_B::I > KQ_stride) {
__syncthreads();
}
+#endif // CP_ASYNC_AVAILABLE
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8 threads each, so a full reduce is not needed.
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory and then one large write to VRAM than to do small writes to VRAM.
// Also, for np > 1 the combination is done via these values in shared memory.
- const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
+ const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
#pragma unroll
- for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
- const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
+ for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+ const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int k = k0 + mma_B::get_k(l);
+ for (int l = 0; l < tile_B::ne; ++l) {
+ const int k = k0 + tile_B::get_j(l);
- tile_KV[j_cwd*D2_padded + k] = B.x[l];
+ tile_K[j_cwd*D2_padded + k] = B.x[l];
}
}
- const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
- const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
+ const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
+ const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
- if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
- ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+ ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
}
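// Note: D2_padded = D/2 + 4 gives exactly 4 half2 = 16 bytes of padding per row; in float2
// units a row has D/4 + 2 elements, so index D/4 above lands in that padding region.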
__syncthreads();
static_assert(np == 1 || np == 2 || np == 4, "bad np");
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
- if (needs_fixup && threadIdx.x < mma_B::J) {
+ if (needs_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
- if (is_fixup && threadIdx.x < mma_B::J) {
+ if (is_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.
- float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
+ float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
- if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+ if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
KQ_cm = meta_j[0];
}
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
#pragma unroll
- for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+ for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
- if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+ if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
KQ_crs = KQ_cms*meta_j[1];
}
#pragma unroll
- for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+ for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
// Write back combined meta data:
- if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
- meta_j[0] = KQ_cmn; // Combined max. KQ values.
- meta_j[1] = KQ_crs; // Combined KQ rowsums.
- meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
+ if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
+ *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
}
- if (needs_fixup && threadIdx.x < mma_B::J) {
+ if (needs_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
- dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+ dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
- if (is_fixup && threadIdx.x < mma_B::J) {
+ if (is_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
- dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+ dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
}
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
#pragma unroll
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
- const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
+ const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
if (!is_fixup && jt*ncols + j_dst >= ne01) {
continue;
}
- const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
+ const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
for (int ip = 0; ip < np; ++ip) {
- const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
- const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
+ const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
dstk_val.x += dstk_val_add.x*KQ_crs;
dstk_val.y += dstk_val_add.y*KQ_crs;
}
__syncthreads();
}
#else
- NO_DEVICE_CODE;
+ NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
const int ne1,
const int ne2,
const int ne3) {
+#ifndef NEW_MMA_AVAILABLE
+ NO_DEVICE_CODE;
+ return;
+#endif // NEW_MMA_AVAILABLE
+
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const int stride_Q = nb01 / sizeof(float2);
+ const int stride_KV = nb11 / sizeof(half2);
+ const int stride_mask = nb31 / sizeof(half);
+
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
- ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
- jt, kb0_start, kb0_stop);
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
- ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
- jt, kb0_start, kb0_stop);
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
}
kbc += iter_k;
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
- ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
- jt, kb0_start, kb0_stop);
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
}
template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- typedef mma_A_I16K8<half2> mma_A;
- typedef mma_B_J8K8<half2> mma_B;
+ typedef tile<16, 8, half2> tile_A;
+ typedef tile< 8, 8, half2> tile_B;
- static_assert(D % mma_B::K == 0, "bad D");
- static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
+ static_assert(D % tile_B::J == 0, "bad D");
+ static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
const ggml_tensor * KQV = dst;
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+ constexpr int KQ_stride = D <= 128 ? 64 : 32;
+ constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
+ cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
- constexpr int KQ_stride = D <= 128 ? 64 : 32;
- constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
- cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
- constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
+ const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
+ const int nrows_combine = nwarps*tile_B::J;
+ const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
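+    // Example instantiation (illustrative): for D = 128 with cols_per_block > 8 and cp.async
+    // available, KQ_stride = 64, nwarps = 8, nrows_KQ = 128, nrows_combine = 64, so
+    // nbytes_shared = 128 * (128 + 8) * sizeof(half) = 34816 bytes (34 KiB).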
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
-// A is a row-major matrix with shape I x K.
-// B is a column-major matrix with shape K x J.
-// C is a column-major matrix with shape I x J.
-// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
-// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// A is a row-major matrix with shape M x K.
+// B is a column-major matrix with shape K x N.
+// C is a column-major matrix with shape M x N.
+// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
+// Note that J is measured in physical 32 bit elements instead of logical elements.
+// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per thread.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
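// For example, the 16 x 16 half A operand of an m16n8k16 mma is held as tile<16, 8, half2>
// (this is how tile_A is used in the FlashAttention kernel earlier in this patch): J counts
// half2 elements, so each thread in a 32 thread warp holds ne = 16*8/32 = 4 of them.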
#ifdef NEW_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
- : "+r"(ret) : "r"(x));
+ : "=r"(ret) : "r"(x));
#else
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
#endif // CUDART_VERSION >= 11080
+static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
+ half2 ret;
+ *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
+ return ret;
+}
-template <typename T>
-struct mma_A_I16K4 {
- static_assert(sizeof(T) == 4, "bad type size");
-
- static constexpr int I = 16;
- static constexpr int K = 4;
- static constexpr int ne = 2;
-
- T x[ne];
+namespace ggml_cuda_mma {
+
+ template <int I_, int J_, typename T>
+ struct tile {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr int ne = I * J / WARP_SIZE;
+ T x[ne] = {0};
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && (J == 4 || J == 8)) {
+ return threadIdx.x / 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return (l / 2) * 8 + threadIdx.x / 4;
+ } else {
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
+ }
+ }
- static __device__ __forceinline__ int get_i(const int l) {
- const int ret = (l%2) * (I/2) + threadIdx.x / K;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < I);
- return ret;
- }
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return threadIdx.x % 4;
+ } else if constexpr (I == 8 && J == 8) {
+ return 4 * l + threadIdx.x % 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return 2 * (threadIdx.x % 4) + l % 2;
+ } else {
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
+ }
+ }
+ };
+
+ template <int I_, int J_>
+ struct tile<I_, J_, half2> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr int ne = I * J / WARP_SIZE;
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && J == 8) {
+ return threadIdx.x / 4;
+ } else if constexpr (I == 16 && J == 4) {
+ return l * 8 + threadIdx.x / 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return (l % 2) * 8 + threadIdx.x / 4;
+ } else {
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
+ }
+ }
- static __device__ __forceinline__ int get_k(const int /* l */) {
- const int ret = threadIdx.x % K;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < K);
- return ret;
- }
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 8) {
+ return l * 4 + threadIdx.x % 4;
+ } else if constexpr (I == 16 && J == 4) {
+ return threadIdx.x % 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return (l / 2) * 4 + threadIdx.x % 4;
+ } else {
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
+ }
+ }
+ };
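+    // Illustrative mapping (follows directly from get_i/get_j above): for tile<16, 8, half2>,
+    // lane 5 owns ne = 4 elements at (i, j) = (1, 1), (9, 1), (1, 5), (9, 5); e.g. l = 2 gives
+    // i = (2 % 2)*8 + 5/4 = 1 and j = (2 / 2)*4 + 5 % 4 = 5.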
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
+ template <int I, int J>
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+ tile<I, J/2, half2> ret;
#pragma unroll
- for (int l = 0; l < ne; ++l) {
- x[l] = xs0[get_i(l)*stride + get_k(l)];
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+ ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
}
- }
-
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
-#ifdef NEW_MMA_AVAILABLE
- int * xi = (int *) x;
- const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
- asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
- : "+r"(xi[0]), "+r"(xi[1])
- : "l"(xs));
-#else
- load_generic(xs0, stride);
-#endif // NEW_MMA_AVAILABLE
- }
-};
-
-template <typename T>
-struct mma_A_I16K8 {
- static_assert(sizeof(T) == 4, "bad type size");
-
- static constexpr int I = 16;
- static constexpr int K = 8;
- static constexpr int ne = 4;
-
- T x[ne];
-
- static __device__ __forceinline__ int get_i(const int l) {
- const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < I);
return ret;
}
- static __device__ __forceinline__ int get_k(const int l) {
- const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < K);
+ static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+ tile<8, 8, half2> ret;
+ ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
+ ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
+
return ret;
}
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
+ template <int I, int J, typename T>
+ static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#pragma unroll
- for (int l = 0; l < ne; ++l) {
- x[l] = xs0[get_i(l)*stride + get_k(l)];
+ for (int l = 0; l < t.ne; ++l) {
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
}
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
- int * xi = (int * ) x;
- const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
- asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+ int * xi = (int *) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+ : "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
- GGML_UNUSED(xs0);
- GGML_UNUSED(stride);
- NO_DEVICE_CODE;
+ load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
- __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
- int * xi = (int * ) x;
- const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
- asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
- : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
+ int * xi = (int *) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+ : "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
- GGML_UNUSED(xs0);
- GGML_UNUSED(stride);
- NO_DEVICE_CODE;
+ load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
- __device__ __forceinline__ void transpose() {
- int * xi = (int *) x;
- xi[0] = ggml_cuda_movmatrix(xi[0]);
-
- const int tmp = ggml_cuda_movmatrix(xi[1]);
- xi[1] = ggml_cuda_movmatrix(xi[2]);
- xi[2] = tmp;
-
- xi[3] = ggml_cuda_movmatrix(xi[3]);
- }
-};
-
-template <typename T>
-struct mma_B_J8K4 {
- static_assert(sizeof(T) == 4, "bad type size");
-
- static constexpr int J = 8;
- static constexpr int K = 4;
- static constexpr int ne = 1;
-
- T x[ne];
-
- static __device__ __forceinline__ int get_j(const int /* l */) {
- const int ret = threadIdx.x / K;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < J);
- return ret;
- }
-
- static __device__ __forceinline__ int get_k(const int /* l */) {
- const int ret = threadIdx.x % K;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < K);
- return ret;
- }
-
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
- for (int l = 0; l < ne; ++l) {
- x[l] = xs0[get_j(l)*stride + get_k(l)];
- }
- }
-
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
- int * xi = (int *) x;
- const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
- asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
- : "+r"(xi[0]) : "l"(xs));
+ int * xi = (int * ) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+ : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
+ : "l"(xs));
#else
- load_generic(xs0, stride);
+ load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
-};
-
-template <typename T>
-struct mma_B_J8K8 {
- static_assert(sizeof(T) == 4, "bad type size");
-
- static constexpr int J = 8;
- static constexpr int K = 8;
- static constexpr int ne = 2;
- T x[ne];
-
- static __device__ __forceinline__ int get_j(const int /* l */) {
- const int ret = threadIdx.x / (K/2);
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < J);
- return ret;
- }
-
- static __device__ __forceinline__ int get_k(const int l) {
- const int ret = l * (K/2) + threadIdx.x % (K/2);
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < K);
- return ret;
- }
-
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
- for (int l = 0; l < ne; ++l) {
- x[l] = xs0[get_j(l)*stride + get_k(l)];
- }
- }
-
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix_trans(
+ tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
- int * xi = (int *) x;
- const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
- asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
- : "+r"(xi[0]), "+r"(xi[1])
+ int * xi = (int * ) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+ : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
: "l"(xs));
#else
- load_generic(xs0, stride);
+ GGML_UNUSED(t);
+ GGML_UNUSED(xs0);
+ GGML_UNUSED(stride);
+ NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
-};
-
-template <typename T>
-struct mma_C_I16J8 {};
-
-template <>
-struct mma_C_I16J8<int> {
- static constexpr int I = 16;
- static constexpr int J = 8;
- static constexpr int ne = 4;
- int x[ne] = {0};
-
- static __device__ __forceinline__ int get_i(const int l) {
- const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < I);
- return ret;
- }
-
- static __device__ __forceinline__ int get_j(const int l) {
- const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < J);
- return ret;
- }
-
- __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
- : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
- : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+ : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
- : "+r"(x[0]), "+r"(x[1])
- : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+ : "+r"(D.x[0]), "+r"(D.x[1])
+ : "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
- : "+r"(x[2]), "+r"(x[3])
- : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+ : "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
- GGML_UNUSED(mma_A);
- GGML_UNUSED(mma_B);
+ GGML_UNUSED(D);
+ GGML_UNUSED(A);
+ GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
- __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
- : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
- : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+ : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
- : "+r"(x[0]), "+r"(x[1])
- : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+ : "+r"(D.x[0]), "+r"(D.x[1])
+ : "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
- : "+r"(x[2]), "+r"(x[3])
- : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+ : "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[1]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
- : "+r"(x[0]), "+r"(x[1])
- : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+ : "+r"(D.x[0]), "+r"(D.x[1])
+ : "r"(A.x[2]), "r"(B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
- : "+r"(x[2]), "+r"(x[3])
- : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+ : "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
- GGML_UNUSED(mma_A);
- GGML_UNUSED(mma_B);
+ GGML_UNUSED(D);
+ GGML_UNUSED(A);
+ GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
-};
-
-template <>
-struct mma_C_I16J8<half2> {
- static constexpr int I = 16;
- static constexpr int J = 4;
- static constexpr int ne = 2;
-
- half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
-
- static __device__ __forceinline__ int get_i(const int l) {
- const int ret = l * (I/2) + threadIdx.x / J;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < I);
- return ret;
- }
- static __device__ __forceinline__ int get_j(const int /* l */) {
- const int ret = threadIdx.x % J;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < J);
- return ret;
- }
-
- __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+ static __device__ __forceinline__ void mma(
+ tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
- int * Axi = (int *) mma_A.x;
- int * Bxi = (int *) mma_B.x;
- int * xi = (int *) x;
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
- : "+r"(xi[0]), "+r"(xi[1])
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
- : "+r"(xi[0]), "+r"(xi[1])
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
- : "+r"(xi[0]), "+r"(xi[1])
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
- GGML_UNUSED(mma_A);
- GGML_UNUSED(mma_B);
+ GGML_UNUSED(D);
+ GGML_UNUSED(A);
+ GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
- __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
- mma_B_J8K8<half2> mma_B;
-
- int * xi = (int *) x;
- int * Bxi = (int *) mma_B.x;
- Bxi[0] = ggml_cuda_movmatrix(xi[0]);
- Bxi[1] = ggml_cuda_movmatrix(xi[1]);
-
- return mma_B;
- }
-};
-
-template <>
-struct mma_C_I16J8<float> {
- static constexpr int I = 16;
- static constexpr int J = 8;
- static constexpr int ne = 4;
-
- float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
-
- static __device__ __forceinline__ int get_i(const int l) {
- const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < I);
- return ret;
- }
-
- static __device__ __forceinline__ int get_j(const int l) {
- const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
- GGML_CUDA_ASSUME(ret >= 0);
- GGML_CUDA_ASSUME(ret < J);
- return ret;
- }
-
- __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
- int * Axi = (int *) mma_A.x;
- int * Bxi = (int *) mma_B.x;
- int * xi = (int *) x;
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
- GGML_UNUSED(mma_A);
- GGML_UNUSED(mma_B);
+ GGML_UNUSED(D);
+ GGML_UNUSED(A);
+ GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
- __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
- mma_B_J8K8<half2> mma_B;
- mma_B.x[0] = make_half2(x[0], x[1]);
- mma_B.x[1] = make_half2(x[2], x[3]);
-
- int * Bxi = (int *) mma_B.x;
- Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
- Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
-
- return mma_B;
- }
-
- __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
-#pragma unroll
- for (int l = 0; l < ne; ++l) {
- x[l] = xs0[get_j(l)*stride + get_i(l)];
- }
- }
-};
+}
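
// For reference, a minimal usage sketch of the refactored ggml_cuda_mma API above:
// tiles are plain structs and all loading/multiplication goes through free functions
// rather than member functions. The buffer names (tile_a, tile_b, lda, ldb) and the
// particular load overloads chosen here are illustrative assumptions, not part of the
// patch itself.
//
//   using namespace ggml_cuda_mma;
//
//   tile<16, 8, half2> A;
//   tile< 8, 8, half2> B;
//   tile<16, 8, float> D;            // accumulator, zero-initialized like the old mma_C_* structs
//   load_ldmatrix(A, tile_a, lda);   // ldmatrix-based load from shared memory
//   load_generic (B, tile_b, ldb);   // plain per-element load
//   mma(D, A, B);                    // D += A * B, accumulating in place
//   // The thread-local value D.x[l] corresponds to logical element
//   // (tile<16, 8, float>::get_i(l), tile<16, 8, float>::get_j(l)) of the 16x8 output.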
#include <climits>
#include <cstdint>
+using namespace ggml_cuda_mma;
+
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_NWARPS 8
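
// Rough meaning of the MMQ constants above (descriptive note, inferred from how the
// kernels below use them): MMQ_ITER_K is the number of values along the K dimension
// consumed by one outer iteration of an MMQ kernel, and MMQ_NWARPS is the number of
// warps launched per thread block.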
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
- typedef mma_A_I16K8<int> mma_A;
- typedef mma_B_J8K8<int> mma_B;
- typedef mma_C_I16J8<int> mma_C;
+ typedef tile<16, 8, int> tile_A;
+ typedef tile< 8, 8, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
- mma_A A[ntx][WARP_SIZE/QI8_0];
- float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+ tile_A A[ntx][WARP_SIZE/QI8_0];
+ float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
const int k0 = k00 + k01;
- A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+ load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
}
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
- mma_B B;
- float dB[mma_C::ne/2];
+ tile_B B;
+ float dB[tile_C::ne/2];
- B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma(A[n][k01/QI8_0], B);
+ tile_C C;
+ mma(C, A[n][k01/QI8_0], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
}
}
}
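
// The accumulation above follows the usual ggml q8_0 x q8_1 scheme: the integer mma
// yields sumi = sum(q_x * q_y) for each output element, and the contribution to the
// final result is d_x * d_y * sumi, which is the C.x[l]*dA[...]*dB[...] product (dA
// taken from x_df, dB from y_df or y_ds depending on ds_layout).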
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
- typedef mma_A_I16K8<int> mma_A;
- typedef mma_B_J8K8<int> mma_B;
- typedef mma_C_I16J8<int> mma_C;
+ typedef tile<16, 8, int> tile_A;
+ typedef tile< 8, 8, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
- mma_A A[ntx][WARP_SIZE/QI8_1];
- float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+ tile_A A[ntx][WARP_SIZE/QI8_1];
+ float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
- A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+ load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
}
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
- mma_B B;
- float2 dsB[mma_C::ne/2];
+ tile_B B;
+ float2 dsB[tile_C::ne/2];
- B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma(A[n][k01/QI8_1], B);
+ tile_C C;
+ mma(C, A[n][k01/QI8_1], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
}
}
}
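
// The two accumulation lines above assume the standard ggml q8_1/Q*_1 packing: x_dm
// holds {d_x, m_x} and y_dm holds {d_y, s_y}, where s_y is the (scaled) sum of the y
// quants. Because a dequantized x value is d_x*q_x + m_x, the dot product splits into
// d_x*d_y*sum(q_x*q_y) + m_x*s_y, i.e. the dmA.x*dsB.x*C.x[l] term plus the
// dmA.y*dsB.y correction.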
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
- typedef mma_A_I16K4<int> mma_A;
- typedef mma_A_I16K8<int> mma_A_K8;
- typedef mma_B_J8K4<int> mma_B;
- typedef mma_C_I16J8<int> mma_C;
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 8, int> tile_A_8;
+ typedef tile< 8, 4, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
- mma_A A[ntx][8];
- float dA[ntx][mma_C::ne/2][8];
+ tile_A A[ntx][8];
+ float dA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
- ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
}
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
- mma_B B[2];
- float dB[mma_C::ne/2];
+ tile_B B[2];
+ float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
- B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
- B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- mma_C C[2];
- C[0].mma(A[n][k01/4 + 0], B[0]);
- C[1].mma(A[n][k01/4 + 1], B[1]);
+ tile_C C[2];
+ mma(C[0], A[n][k01/4 + 0], B[0]);
+ mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
}
}
}
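
// Two details of the q3_K path above: the ((tile_A_8 *) A[n])[k01/8] cast relies on
// two adjacent tile<16, 4, int> fragments being contiguous in registers, so a single
// K=8 ldmatrix load fills both at once; and every 8-wide slice of y is consumed as a
// B[0]/B[1] pair of tile<8, 4, int> fragments matched against A[n][k01/4 + 0] and
// A[n][k01/4 + 1].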
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
- typedef mma_A_I16K4<int> mma_A;
- typedef mma_A_I16K8<int> mma_A_K8;
- typedef mma_B_J8K4<int> mma_B;
- typedef mma_C_I16J8<int> mma_C;
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 8, int> tile_A_8;
+ typedef tile< 8, 4, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
- mma_A A[ntx][8];
- float dA[ntx][mma_C::ne/2][8];
- float mA[ntx][mma_C::ne/2][8];
+ tile_A A[ntx][8];
+ float dA[ntx][tile_C::ne/2][8];
+ float mA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
- ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
}
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- float2 dB[mma_C::ne/2];
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ float2 dB[tile_C::ne/2];
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
- mma_B B[2];
+ tile_B B[2];
// Here load_generic is faster than load_ldmatrix.
- B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
- B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
- mma_C Cm[2];
+ tile_C Cm[2];
if (k01 >= WARP_SIZE * 3/4) {
- mma_A A1;
+ tile_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
- Cm[0].mma(A1, B[0]);
- Cm[1].mma(A1, B[1]);
+ mma(Cm[0], A1, B[0]);
+ mma(Cm[1], A1, B[1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- mma_C Cd[2];
+ tile_C Cd[2];
- Cd[0].mma(A[n][k01/4 + 0], B[0]);
- Cd[1].mma(A[n][k01/4 + 1], B[1]);
+ mma(Cd[0], A[n][k01/4 + 0], B[0]);
+ mma(Cd[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
+ for (int l = 0; l < tile_C::ne; ++l) {
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
if (k01 >= WARP_SIZE * 3/4) {
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
}
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
}
}
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
- float2 sB[mma_C::ne/2];
+ float2 sB[tile_C::ne/2];
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
- sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
}
}
}
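
// In the q2_K path above, filling A1 with 0x01010101 turns the integer mma into a
// column sum of B's int8 values, so Cm holds sum(q_y) for the corresponding fragment.
// Weighted by the per-sub-block minimums in mA, that sum is subtracted from tmp to
// account for the m*sum(q_y) part of the q2_K dot product; for the sub-blocks with
// k01 < 3*WARP_SIZE/4 the same correction is applied in the trailing loop using the
// precomputed partial sums sB from y_ds.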
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
- typedef mma_A_I16K4<int> mma_A;
- typedef mma_B_J8K4<int> mma_B;
- typedef mma_C_I16J8<int> mma_C;
+ typedef tile<16, 4, int> tile_A;
+ typedef tile< 8, 4, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
- mma_A A[ntx][8];
- int scA[ntx][mma_C::ne/2][8];
- float dA[ntx][mma_C::ne/2];
+ tile_A A[ntx][8];
+ int scA[ntx][tile_C::ne/2][8];
+ float dA[ntx][tile_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
- A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
- A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+ load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+ load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
}
#pragma unroll
const int k0 = k00 + k01;
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
const int8_t * sc = (const int8_t *) &sc_packed;
}
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
}
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- float tmp[ntx][mma_C::ne] = {{0.0f}};
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ float tmp[ntx][tile_C::ne] = {{0.0f}};
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
- mma_B B[2];
- float dB[mma_C::ne/2];
+ tile_B B[2];
+ float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
- B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
- B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- mma_C C[2];
- C[0].mma(A[n][k01/4 + 0], B[0]);
- C[1].mma(A[n][k01/4 + 1], B[1]);
+ tile_C C[2];
+ mma(C[0], A[n][k01/4 + 0], B[0]);
+ mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
+ for (int l = 0; l < tile_C::ne; ++l) {
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
}
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
}
}
}
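
// The q6_K path above accumulates (C.x[l]*scA)*dB into the per-tile tmp buffer using
// the unpacked int8 sub-block scales scA, and multiplies by the single float
// super-block scale dA only once per output element at the end, rather than inside
// every k01 iteration.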
static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
- typedef mma_C_I16J8<int> mma_C;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
#ifdef NEW_MMA_AVAILABLE
- static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+ static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#endif // NEW_MMA_AVAILABLE
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
if (j > j_max) {
continue;
}
- const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
if (need_check && i > i_max) {
continue;
}
- dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+ dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
}
}
}
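
// Write-back mapping: get_i/get_j return the logical (row, column) of each of the ne
// thread-local accumulator values inside the 16x8 C tile; assuming the int tile uses
// the same layout as the removed mma_C_I16J8<float> helpers, that is
// i = (l/2)*(I/2) + threadIdx.x/(J/2) and j = 2*(threadIdx.x % (J/2)) + l%2. The
// j_max/i_max checks simply skip elements that fall outside a partial output tile.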