            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
            break;
        // case 256:
-       //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+       //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
        //     break;
        default:
            GGML_ABORT("fatal error");
            return;
    }
-   if (!new_mma_available(cc)) {
+   if (!fp16_mma_available(cc)) {
        if (prec == GGML_PREC_DEFAULT) {
            if (Q->ne[1] <= 8) {
                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);

    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
    if (cc == GGML_CUDA_CC_VOLTA) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+       return;
    }

    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
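
// Not part of the diff: a minimal sketch of the resulting kernel selection, assuming
// ctx, dst, Q, cc and prec are set up earlier in the dispatch function as in the hunks
// above (CUDA context, output tensor, query tensor, compute capability, requested
// precision). Branches elided from this excerpt are elided here as well.
if (!fp16_mma_available(cc)) {
    // No FP16 tensor-core path on this device: fall back to the non-MMA kernels.
    if (prec == GGML_PREC_DEFAULT && Q->ne[1] <= 8) {
        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); // small query batches: vector kernel
    }
    // (remaining fallback branches elided)
    return;
}

// FP16 MMA hardware exists on Volta, but the new MMA kernels need Turing or newer,
// so Volta is routed to the older WMMA implementation:
if (cc == GGML_CUDA_CC_VOLTA) {
    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
    return;
}

ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); // Turing and newer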