// Host function: returns the max batch size for the current arch+type at runtime.
//
// type: quantized tensor type being multiplied.
// cc:   compute-capability / architecture id of the active device.
//
// Dispatches to a per-architecture tuning helper; architectures without a
// tuned table (and unrecognized vendors) fall back to MMVQ_MAX_BATCH_SIZE.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
        // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
        if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
            return MMVQ_MAX_BATCH_SIZE;
        }
        if (cc >= GGML_CUDA_CC_TURING) {
            return get_mmvq_mmid_max_batch_turing_plus(type);
        }
        // Pascal and older NVIDIA architectures.
        return get_mmvq_mmid_max_batch_pascal_older(type);
    }

    // AMD
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
            return get_mmvq_mmid_max_batch_rdna4(type);
        }
        if (GGML_CUDA_CC_IS_RDNA3(cc)) {
            return get_mmvq_mmid_max_batch_rdna3(type);
        }
        if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
            return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
        }
        if (GGML_CUDA_CC_IS_CDNA(cc)) {
            return get_mmvq_mmid_max_batch_cdna(type);
        }
        if (GGML_CUDA_CC_IS_GCN(cc)) {
            return get_mmvq_mmid_max_batch_gcn(type);
        }
    }
    // Unknown vendor/architecture: use the generic upper bound.
    return MMVQ_MAX_BATCH_SIZE;
}