git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
finetune : fix #3404 (#3437)
author xaedes <redacted>
Mon, 2 Oct 2023 13:15:45 +0000 (15:15 +0200)
committer GitHub <redacted>
Mon, 2 Oct 2023 13:15:45 +0000 (16:15 +0300)
The shapes asserted in init_model for GQA models were wrong: the second dimension of wk and wv is n_embd_gqa(), not n_embd.

examples/finetune/finetune.cpp

index 8ca1874dafc7e98c99c1c536b18b3577468516a1..9ae4bc1981bde802bc8d09b9ac85def2adbe7d95 100644 (file)
@@ -332,8 +332,8 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
 
         assert_shape_1d(layer.attention_norm, hparams.n_embd);
         assert_shape_2d(layer.wq,             hparams.n_embd, hparams.n_embd);
-        assert_shape_2d(layer.wk,             hparams.n_embd, hparams.n_embd);
-        assert_shape_2d(layer.wv,             hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wk,             hparams.n_embd, hparams.n_embd_gqa());
+        assert_shape_2d(layer.wv,             hparams.n_embd, hparams.n_embd_gqa());
         assert_shape_2d(layer.wo,             hparams.n_embd, hparams.n_embd);
         assert_shape_1d(layer.ffn_norm,       hparams.n_embd);
         assert_shape_2d(layer.w1,             hparams.n_embd, hparams.n_ff);