]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : do not use F16 tensors when in F32 mode (#369)
authorGeorgi Gerganov <redacted>
Thu, 5 Jan 2023 20:56:25 +0000 (22:56 +0200)
committerGeorgi Gerganov <redacted>
Thu, 5 Jan 2023 20:56:25 +0000 (22:56 +0200)
whisper.cpp

index a61ededd21b8003d1940d7ae9fa36d9e95254a0c..e6d050a914549cecd14e778cfec0b1406ee5ec24 100644 (file)
@@ -412,6 +412,8 @@ struct whisper_context {
     std::vector<uint8_t>   buf_compute;
     std::vector<uint8_t>   buf_compute_layer;
 
+    ggml_type wtype; // weight type (FP32 or FP16)
+
     whisper_model model;
     whisper_vocab vocab;
 
@@ -629,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
     // for the big tensors, we have the option to store the data in 16-bit floats
     // in order to save memory and also to speed up the computation
-    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    const ggml_type wtype = wctx.wtype;
 
     size_t ctx_size = 0;
 
@@ -650,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // encoder
         {
-            // TODO: F16 .. maybe not?
             ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
 
             ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype);         // e_conv_1_w
@@ -665,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // decoder
         {
-            // TODO: F16 .. maybe not?
             ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
 
             ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@@ -982,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+            model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
         // key/value memory for the cross-attention layer
@@ -993,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+            model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
         }
 
         const size_t memory_size =
@@ -1240,14 +1242,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * V =
@@ -1257,7 +1259,7 @@ static bool whisper_encode(
                                 Vcur,
                                 n_state/n_head, n_head, n_ctx),
                             1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1273,7 +1275,7 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+                            ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
@@ -1291,7 +1293,7 @@ static bool whisper_encode(
             //    ggml_permute(ctxL,
             //            ggml_cpy(ctxL,
             //                Vcur,
-            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+            //                ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
             //            1, 2, 0, 3);
 
             //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1303,7 +1305,7 @@ static bool whisper_encode(
                                 Vcur,
                                 n_state/n_head, n_head, n_ctx),
                             0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
+                        ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1348,7 +1350,7 @@ static bool whisper_encode(
 
 #ifdef USE_FLASH_FF
             cur = ggml_flash_ff(ctxL,
-                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
+                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
                     layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
             // fully connected
@@ -3156,7 +3158,7 @@ int whisper_full_parallel(
 
         // separate key + value memory for each processor
         {
-            auto & ctx = model.ctx_mem;
+            auto & mctx = model.ctx_mem;
 
             const auto & hparams = model.hparams;
 
@@ -3169,8 +3171,8 @@ int whisper_full_parallel(
                 const int n_mem      = n_text_layer*n_text_ctx;
                 const int n_elements = n_text_state*n_mem;
 
-                model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-                model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+                model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
             }
 
             // key/value memory for the cross-attention layer
@@ -3180,8 +3182,8 @@ int whisper_full_parallel(
                 const int n_mem      = n_text_layer*n_audio_ctx;
                 const int n_elements = n_text_state*n_mem;
 
-                model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-                model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+                model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
             }
         }
     }