]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : various fixes
authorGeorgi Gerganov <redacted>
Mon, 3 Oct 2022 16:31:17 +0000 (19:31 +0300)
committerGeorgi Gerganov <redacted>
Mon, 3 Oct 2022 16:31:17 +0000 (19:31 +0300)
examples/whisper/main.cpp

index 79935c9cd58292afc00e99be862911aed23a75a3..b39f36016c95dfef090f4909adf353148cdf6fe9 100644 (file)
@@ -1859,7 +1859,7 @@ whisper_vocab::id whisper_sample_best(
     if (need_timestamp) {
         // at the end of the 30-second audio segment, we start giving preference to time tokens
         for (int i = 0; i < top_k; i++) {
-            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) {
+            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
                 return probs_id[i].second;
             }
         }
@@ -1909,8 +1909,31 @@ whisper_vocab::id whisper_sample_timestamp(
     return probs_id[0].second;
 }
 
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+void dft(const std::vector<float> & in, std::vector<float> & out) {
+    int N = in.size();
+
+    out.resize(N*2);
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            float angle = 2*M_PI*k*n/N;
+            re += in[n]*cos(angle);
+            im -= in[n]*sin(angle);
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
 // Cooley-Tukey FFT
-// poor man's implmentation - use something better
+// poor man's implementation - use something better
 // input is real-valued
 // output is complex-valued
 void fft(const std::vector<float> & in, std::vector<float> & out) {
@@ -1924,6 +1947,11 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
         return;
     }
 
+    if (N%2 == 1) {
+        dft(in, out);
+        return;
+    }
+
     std::vector<float> even;
     std::vector<float> odd;
 
@@ -2014,9 +2042,20 @@ bool log_mel_spectrogram(
                 // FFT -> mag^2
                 fft(fft_in, fft_out);
 
-                for (int j = 0; j < n_fft; j++) {
+                for (int j = 0; j < fft_size; j++) {
                     fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
                 }
+                for (int j = 1; j < fft_size/2; j++) {
+                    //if (i == 0) {
+                    //    printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
+                    //}
+                    fft_out[j] += fft_out[fft_size - j];
+                }
+                if (i == 0) {
+                    //for (int j = 0; j < fft_size; j++) {
+                    //    printf("%d: %e\n", j, fft_out[j]);
+                    //}
+                }
 
                 // mel spectrogram
                 for (int j = 0; j < mel.n_mel; j++) {
@@ -2048,6 +2087,7 @@ bool log_mel_spectrogram(
             mmax = mel.data[i];
         }
     }
+    //printf("%s: max = %f\n", __func__, mmax);
 
     mmax -= 8.0;
 
@@ -2125,8 +2165,8 @@ int main(int argc, char ** argv) {
             return 2;
         }
 
-        if (wav.channels != 1) {
-            fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str());
+        if (wav.channels != 1 && wav.channels != 2) {
+            fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
             return 3;
         }
 
@@ -2140,15 +2180,23 @@ int main(int argc, char ** argv) {
             return 5;
         }
 
+        int n = wav.totalPCMFrameCount;
+
         std::vector<int16_t> pcm16;
-        pcm16.resize(wav.totalPCMFrameCount);
-        drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
+        pcm16.resize(n*wav.channels);
+        drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
         drwav_uninit(&wav);
 
-        // convert to float
-        pcmf32.resize(pcm16.size());
-        for (size_t i = 0; i < pcm16.size(); i++) {
-            pcmf32[i] = float(pcm16[i])/32768.0f;
+        // convert to mono, float
+        pcmf32.resize(n);
+        if (wav.channels == 1) {
+            for (size_t i = 0; i < n; i++) {
+                pcmf32[i] = float(pcm16[i])/32768.0f;
+            }
+        } else {
+            for (size_t i = 0; i < n; i++) {
+                pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+            }
         }
     }
 
@@ -2195,7 +2243,7 @@ int main(int argc, char ** argv) {
     }
 
     // the generated text including timestamps
-    std::vector<whisper_result> result_all;
+    //std::vector<whisper_result> result_all;
 
     // main loop
     int seek = 0;
@@ -2252,7 +2300,7 @@ int main(int argc, char ** argv) {
         int result_len = 0;
         std::vector<whisper_result> result_cur;
 
-        for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) {
+        for (int i = 0; i < model.hparams.n_text_ctx/2 - 4; ++i) {
             // decode
             if (prompt.size() > 0) {
                 const int64_t t_start_us = ggml_time_us();
@@ -2317,7 +2365,7 @@ int main(int argc, char ** argv) {
         }
 
         result_cur.resize(result_len);
-        result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
+        //result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
 
         for (const auto & r : result_cur) {
             prompt_past.push_back(r.id);