]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use switch statements in constexpr functions (#13095)
authorJohannes Gäßler <redacted>
Thu, 24 Apr 2025 13:57:10 +0000 (15:57 +0200)
committerGitHub <redacted>
Thu, 24 Apr 2025 13:57:10 +0000 (15:57 +0200)
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/mmvq.cu

index 532358018f410c30f65e4c9926e0d1068d08f46e..3cb2015520ba14c99eb399d81af71666435c17af 100644 (file)
@@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() {
 #define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
-    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
-        type == GGML_TYPE_Q4_1    ? MMQ_DP4A_TXS_Q4_1 :
-        type == GGML_TYPE_Q5_0    ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_Q5_1    ? MMQ_DP4A_TXS_Q8_1 :
-        type == GGML_TYPE_Q8_0    ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_Q2_K    ? MMQ_DP4A_TXS_Q2_K :
-        type == GGML_TYPE_Q3_K    ? MMQ_DP4A_TXS_Q3_K :
-        type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
-        type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
-        type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
-        type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_IQ2_XS  ? MMQ_DP4A_TXS_Q8_0_16 :
-        type == GGML_TYPE_IQ2_S   ? MMQ_DP4A_TXS_Q8_0_16 :
-        type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_IQ3_S   ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_IQ1_S   ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q8_0 :
-        tile_x_sizes{0, 0, 0};
+    switch (type) {
+        case GGML_TYPE_Q4_0:    return MMQ_DP4A_TXS_Q4_0;
+        case GGML_TYPE_Q4_1:    return MMQ_DP4A_TXS_Q4_1;
+        case GGML_TYPE_Q5_0:    return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;
+        case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;
+        case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;
+        case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;
+        case GGML_TYPE_Q5_K:    return MMQ_DP4A_TXS_Q5_K;
+        case GGML_TYPE_Q6_K:    return MMQ_DP4A_TXS_Q6_K;
+        case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_IQ2_XS:  return MMQ_DP4A_TXS_Q8_0_16;
+        case GGML_TYPE_IQ2_S:   return MMQ_DP4A_TXS_Q8_0_16;
+        case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_IQ3_S:   return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_IQ1_S:   return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_IQ4_XS:  return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_IQ4_NL:  return MMQ_DP4A_TXS_Q8_0;
+        default:                return tile_x_sizes{0, 0, 0};
+    }
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
@@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
-    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
-        type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
-        type == GGML_TYPE_Q8_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_Q2_K    ? MMQ_MMA_TILE_X_K_Q2_K :
-        type == GGML_TYPE_Q3_K    ? MMQ_MMA_TILE_X_K_Q3_K :
-        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
-        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
-        type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
-        type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_IQ2_XS  ? MMQ_MMA_TILE_X_K_Q3_K :
-        type == GGML_TYPE_IQ2_S   ? MMQ_MMA_TILE_X_K_Q3_K :
-        type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_IQ3_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_IQ1_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q8_0 :
-        0;
+    switch (type) {
+        case GGML_TYPE_Q4_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_Q4_1:    return MMQ_MMA_TILE_X_K_Q8_1;
+        case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
+        case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
+        case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
+        case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;
+        case GGML_TYPE_Q5_K:    return MMQ_MMA_TILE_X_K_Q8_1;
+        case GGML_TYPE_Q6_K:    return MMQ_MMA_TILE_X_K_Q6_K;
+        case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_IQ2_XS:  return MMQ_MMA_TILE_X_K_Q3_K;
+        case GGML_TYPE_IQ2_S:   return MMQ_MMA_TILE_X_K_Q3_K;
+        case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_IQ3_S:   return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_IQ1_S:   return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_IQ4_XS:  return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_IQ4_NL:  return MMQ_MMA_TILE_X_K_Q8_0;
+        default:                return 0;
+    }
 }
 
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
index cac04916cd8f0b5970f54f95b5e972b919a92ebe..d846e35a6a26d87848e314de7877ba35c30f05e6 100644 (file)
@@ -7,47 +7,51 @@
 typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
 
 static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
-    return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
-        type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
-        type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
-        type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
-        type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
-        type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
-        type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
-        type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
-        type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
-        type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
-        type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
-        type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
-        type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
-        type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
-        type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
-        type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
-        type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
-        type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
-        type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
-        nullptr;
+    switch (type) {
+        case GGML_TYPE_Q4_0:    return vec_dot_q4_0_q8_1;
+        case GGML_TYPE_Q4_1:    return vec_dot_q4_1_q8_1;
+        case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
+        case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
+        case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
+        case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
+        case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
+        case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
+        case GGML_TYPE_Q5_K:    return vec_dot_q5_K_q8_1;
+        case GGML_TYPE_Q6_K:    return vec_dot_q6_K_q8_1;
+        case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
+        case GGML_TYPE_IQ2_XS:  return vec_dot_iq2_xs_q8_1;
+        case GGML_TYPE_IQ2_S:   return vec_dot_iq2_s_q8_1;
+        case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
+        case GGML_TYPE_IQ1_S:   return vec_dot_iq1_s_q8_1;
+        case GGML_TYPE_IQ1_M:   return vec_dot_iq1_m_q8_1;
+        case GGML_TYPE_IQ4_NL:  return vec_dot_iq4_nl_q8_1;
+        case GGML_TYPE_IQ4_XS:  return vec_dot_iq4_xs_q8_1;
+        case GGML_TYPE_IQ3_S:   return vec_dot_iq3_s_q8_1;
+        default:                return nullptr;
+    }
 }
 
 static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
-    return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
-        type == GGML_TYPE_Q4_1    ? VDR_Q4_1_Q8_1_MMVQ :
-        type == GGML_TYPE_Q5_0    ? VDR_Q5_0_Q8_1_MMVQ :
-        type == GGML_TYPE_Q5_1    ? VDR_Q5_1_Q8_1_MMVQ :
-        type == GGML_TYPE_Q8_0    ? VDR_Q8_0_Q8_1_MMVQ :
-        type == GGML_TYPE_Q2_K    ? VDR_Q2_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q3_K    ? VDR_Q3_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q4_K    ? VDR_Q4_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q5_K    ? VDR_Q5_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q6_K    ? VDR_Q6_K_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ2_XS  ? VDR_IQ2_XS_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ2_S   ? VDR_IQ2_S_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ3_S   ? VDR_IQ3_S_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ4_NL  ? VDR_IQ4_NL_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ4_XS  ? VDR_IQ4_XS_Q8_1_MMVQ :
-        1;
+    switch (type) {
+        case GGML_TYPE_Q4_0:    return VDR_Q4_0_Q8_1_MMVQ;
+        case GGML_TYPE_Q4_1:    return VDR_Q4_1_Q8_1_MMVQ;
+        case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
+        case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
+        case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
+        case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q5_K:    return VDR_Q5_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q6_K:    return VDR_Q6_K_Q8_1_MMVQ;
+        case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
+        case GGML_TYPE_IQ2_XS:  return VDR_IQ2_XS_Q8_1_MMVQ;
+        case GGML_TYPE_IQ2_S:   return VDR_IQ2_S_Q8_1_MMVQ;
+        case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
+        case GGML_TYPE_IQ3_S:   return VDR_IQ3_S_Q8_1_MMVQ;
+        case GGML_TYPE_IQ4_NL:  return VDR_IQ4_NL_Q8_1_MMVQ;
+        case GGML_TYPE_IQ4_XS:  return VDR_IQ4_XS_Q8_1_MMVQ;
+        default:                return 1;
+    }
 }
 
 enum mmvq_parameter_table_id {