]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add support for lora adapters in T5 model (#8938)
authorfairydreaming <redacted>
Fri, 9 Aug 2024 16:53:09 +0000 (18:53 +0200)
committerGitHub <redacted>
Fri, 9 Aug 2024 16:53:09 +0000 (18:53 +0200)
Co-authored-by: Stanisław Szymczyk <redacted>
src/llama.cpp

index decdcebdbb6c91d13c5c1b9a5cd701024164abd7..97dd1b3fea4b9444d4a6a1b714091a706613c412 100644 (file)
@@ -13167,13 +13167,13 @@ struct llm_build_context {
 
                 // self-attention
                 {
-                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_enc, cur);
+                    struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
                     cb(Qcur, "Qcur", il);
 
-                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_enc, cur);
+                    struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
                     cb(Kcur, "Kcur", il);
 
-                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_enc, cur);
+                    struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
                     cb(Vcur, "Vcur", il);
 
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -13207,7 +13207,7 @@ struct llm_build_context {
 
                     ggml_build_forward_expand(gf, cur);
 
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wo_enc, cur);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
                     cb(cur, "kqv_out", il);
                 }
 
@@ -13281,13 +13281,13 @@ struct llm_build_context {
 
                 // self-attention
                 {
-                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                    struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                     cb(Qcur, "Qcur", il);
 
-                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                    struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                     cb(Kcur, "Kcur", il);
 
-                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                    struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                     cb(Vcur, "Vcur", il);
 
                     llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
@@ -13334,7 +13334,7 @@ struct llm_build_context {
 
                     ggml_build_forward_expand(gf, cur);
 
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
                     cb(cur, "kqv_out", il);
                 }
 
@@ -13351,13 +13351,13 @@ struct llm_build_context {
 
                 // cross-attention
                 {
-                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_cross, cur);
+                    struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
                     cb(Qcur, "Qcur", il);
 
-                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_cross, embd_enc);
+                    struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
                     cb(Kcur, "Kcur", il);
 
-                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
+                    struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
                     cb(Vcur, "Vcur", il);
 
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
@@ -13386,7 +13386,7 @@ struct llm_build_context {
 
                     ggml_build_forward_expand(gf, cur);
 
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wo_cross, cur);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
                     cb(cur, "kqv_out", il);
                 }
 
@@ -13443,7 +13443,7 @@ struct llm_build_context {
             cb(cur, "result_norm", -1);
 
             // lm_head
-            cur = ggml_mul_mat(ctx0, model.output, cur);
+            cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
             cb(cur, "result_output", -1);
         }