]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
fix potential bug reading model data into a small size optimized string which could...
authorbert hubert <redacted>
Sat, 10 Dec 2022 12:09:31 +0000 (13:09 +0100)
committerGeorgi Gerganov <redacted>
Sat, 10 Dec 2022 14:20:48 +0000 (16:20 +0200)
Also added a small wrapper function to more safely read model data without having to get the sizeof right. I tested this on tiny, base and large models, there was no change in behaviour.

whisper.cpp

index 67451dc80b9b14270f7913431c9e74b5330604a9..2e8ee876ec28a71848b002fc25d0206811bdabcc 100644 (file)
@@ -429,6 +429,12 @@ struct whisper_context {
     int32_t exp_n_audio_ctx; // 0 - use default
 };
 
+template<typename T>
+static void read_safe(std::ifstream& fin, T& dest)
+{
+  fin.read((char*)& dest, sizeof(T));
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     // verify magic
     {
         uint32_t magic;
-        fin.read((char *) &magic, sizeof(magic));
+        read_safe(fin, magic);
         if (magic != 0x67676d6c) {
             fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
             return false;
@@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & hparams = model.hparams;
 
-        fin.read((char *) &hparams.n_vocab,       sizeof(hparams.n_vocab));
-        fin.read((char *) &hparams.n_audio_ctx,   sizeof(hparams.n_audio_ctx));
-        fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
-        fin.read((char *) &hparams.n_audio_head,  sizeof(hparams.n_audio_head));
-        fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
-        fin.read((char *) &hparams.n_text_ctx,    sizeof(hparams.n_text_ctx));
-        fin.read((char *) &hparams.n_text_state,  sizeof(hparams.n_text_state));
-        fin.read((char *) &hparams.n_text_head,   sizeof(hparams.n_text_head));
-        fin.read((char *) &hparams.n_text_layer,  sizeof(hparams.n_text_layer));
-        fin.read((char *) &hparams.n_mels,        sizeof(hparams.n_mels));
-        fin.read((char *) &hparams.f16,           sizeof(hparams.f16));
+        read_safe(fin, hparams.n_vocab);
+        read_safe(fin, hparams.n_audio_ctx);
+        read_safe(fin, hparams.n_audio_state);
+        read_safe(fin, hparams.n_audio_head);
+        read_safe(fin, hparams.n_audio_layer);
+        read_safe(fin, hparams.n_text_ctx);
+        read_safe(fin, hparams.n_text_state);
+        read_safe(fin, hparams.n_text_head);
+        read_safe(fin, hparams.n_text_layer);
+        read_safe(fin, hparams.n_mels);
+        read_safe(fin, hparams.f16);
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
@@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & filters = wctx.model.filters;
 
-        fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
-        fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
+        read_safe(fin, filters.n_mel);
+        read_safe(fin, filters.n_fft);
 
         filters.data.resize(filters.n_mel * filters.n_fft);
         fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     // load vocab
     {
         int32_t n_vocab = 0;
-        fin.read((char *) &n_vocab, sizeof(n_vocab));
+        read_safe(fin, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         std::string word;
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
-            fin.read((char *) &len, sizeof(len));
+            read_safe(fin, len);
 
-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            std::vector<char> tmp(len); // create a buffer
+            fin.read( &tmp[0], tmp.size() ); // read to buffer
+            word.assign(&tmp[0], tmp.size());
 
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
@@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t length;
             int32_t ftype;
 
-            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
-            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            read_safe(fin, n_dims);
+            read_safe(fin, length);
+            read_safe(fin, ftype);
 
             if (fin.eof()) {
                 break;
@@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t nelements = 1;
             int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                read_safe(fin, ne[i]);
                 nelements *= ne[i];
             }
 
-            std::string name(length, 0);
-            fin.read(&name[0], length);
+            std::string name;
+            std::vector<char> tmp(length); // create a buffer
+            fin.read( &tmp[0], tmp.size() ); // read to buffer
+            name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name.data()) == model.tensors.end()) {
                 fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());