]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix q_nope_absorbed prec for DS 2 Lite f16 (#13137)
authorJohannes Gäßler <redacted>
Mon, 28 Apr 2025 07:29:26 +0000 (09:29 +0200)
committerGitHub <redacted>
Mon, 28 Apr 2025 07:29:26 +0000 (09:29 +0200)
ggml/include/ggml.h
ggml/src/ggml-cuda/ggml-cuda.cu
src/llama-model.cpp

index 51aa5b3a0ab44309a1c6dbc0de0ebdd438b8e6c1..1b8603e78e55348c9fcd4db8c39451a55b47c295 100644 (file)
@@ -393,8 +393,8 @@ extern "C" {
 
     // precision
     enum ggml_prec {
-        GGML_PREC_DEFAULT,
-        GGML_PREC_F32,
+        GGML_PREC_DEFAULT =  0, // stored as ggml_tensor.op_params, 0 by default
+        GGML_PREC_F32     = 10,
     };
 
     // model file types
index e0e0d2137f3be28ff70304b5caf14c1e257142de..19b9ce7231aa29086b10077961cbe54baa7f0aca 100644 (file)
@@ -1935,8 +1935,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
-               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+            dst->op_params[0] == GGML_PREC_DEFAULT && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
index 6b7bfecf3a1cf7f9d3172fcbbbcd1ea380cd1e3d..df2791002e9f9fcea66c413bc6a636743fecf04d 100644 (file)
@@ -10149,6 +10149,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
 
                     // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
                     ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
+                    ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32);
                     cb(q_nope_absorbed, "q_nope_absorbed", il);
 
                     // {kv_lora_rank, n_head, n_tokens}