assert_shape_1d(layer.attention_norm, hparams.n_embd);
assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
- assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd);
- assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd);
+ assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd_gqa());
+ assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa());
assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
assert_shape_1d(layer.ffn_norm, hparams.n_embd);
assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);