const int gqa_ratio = Q->ne[2] / K->ne[2];
if (gqa_ratio == 20) { // GLM 4.7 Flash
if (cc >= GGML_CUDA_CC_BLACKWELL) {
+ if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
}
if (cc >= GGML_CUDA_CC_TURING) {
if (Q->ne[1] <= 4) {
+ if (K->ne[1] <= 16384) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
break;
}