const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3);
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33);
typedef half (*vec_dot_KQ_f16_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
mask ? ((const char *) mask->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
- K->ne[0], K->ne[1], K->ne[2], K->ne[3],
- mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
- mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
- Q->nb[1], Q->nb[2], Q->nb[3],
- nb11, nb12, nb13,
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
nb21, nb22, nb23,
- KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
);
CUDA_CHECK(cudaGetLastError());
const int stride_K,
const int stride_V,
const int stride_mask,
- const int jt,
half2 * const __restrict__ tile_Q,
half2 * const __restrict__ tile_K,
half2 * const __restrict__ tile_V,
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
if (nstages <= 1) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
if (use_cp_async) {
cp_async_wait_all();
}
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+ (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}
if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
if (use_cp_async) {
cp_async_wait_all();
}
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+ GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+ (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}
// Iterate over ne11 == previous tokens:
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
// With multi-stage loading there is no __syncthreads at the end of the iter,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
- GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ GGML_UNUSED(nb22); GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
- KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+ KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
}
}
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+ KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
}
}
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ GGML_UNUSED(nb23);
NO_DEVICE_CODE;
return;
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
- const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+ const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
}
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
- KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+ const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+ KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
+ KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
}
}
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
- GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
half2 VKQ[ncols] = {{0.0f, 0.0f}};
+ K += blockIdx.y*D * nb11;
+ V += blockIdx.y*D * nb21;
+ maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
- maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+ maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
- half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+ half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum((float)sum);
if (use_logit_softcap) {
}
half2 V_k;
- reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
- reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+ reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+ reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
}
}
+ K += gridDim.y*D * nb11;
+ V += gridDim.y*D * nb21;
+ maskh += gridDim.y*D;
+
__syncthreads();
}
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ GGML_UNUSED(nb23);
NO_DEVICE_CODE;
return;
}
float VKQ[ncols] = {0.0f};
+ K += blockIdx.y*D * nb11;
+ V += blockIdx.y*D * nb21;
+ maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
- maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+ maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
- float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+ float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
if (use_logit_softcap) {
break;
}
- const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+ const float V_ki = dequantize_1_v(V + k*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*D + k];
}
}
+ K += gridDim.y*D * nb11;
+ V += gridDim.y*D * nb21;
+ maskh += gridDim.y*D;
+
__syncthreads();
}
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
- GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
- wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+ wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
- wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+ wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
- GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
- if (GGML_CUDA_CC_IS_AMD(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
- if (fp16_mma_available(cc)) {
- ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
- return;
- }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
- // On AMD the tile kernels perform poorly, use the vec kernel instead:
- if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
- } else {
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
- }
+ if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {