]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : whisper.cpp
authorGeorgi Gerganov <redacted>
Tue, 18 Oct 2022 16:12:07 +0000 (19:12 +0300)
committerGeorgi Gerganov <redacted>
Tue, 18 Oct 2022 16:12:26 +0000 (19:12 +0300)
- Add MSVC header
- FP16 GELU
- C interface fixes (no unions)
- Minor CMake updates

CMakeLists.txt
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
src/ggml.c
src/msvc_thread_atomic.h [new file with mode: 0644]

index 73e174c58410d4a666f32507e77df67f7eef7dc1..d2f95ccef73df36c787f4226bc728d3c42e3694e 100644 (file)
@@ -15,7 +15,7 @@ endif()
 
 # options
 
-option(GGML_ALL_WARNINGS            "ggml: enable all compiler warnings" ON)
+option(GGML_ALL_WARNINGS            "ggml: enable all compiler warnings"                   ON)
 option(GGML_ALL_WARNINGS_3RD_PARTY  "ggml: enable all compiler warnings in 3rd party libs" OFF)
 
 option(GGML_SANITIZE_THREAD         "ggml: enable thread sanitizer"    OFF)
@@ -25,7 +25,7 @@ option(GGML_SANITIZE_UNDEFINED      "ggml: enable undefined sanitizer" OFF)
 option(GGML_BUILD_TESTS             "ggml: build tests"    ${GGML_STANDALONE})
 option(GGML_BUILD_EXAMPLES          "ggml: build examples" ${GGML_STANDALONE})
 
-option(GGML_PERF                    "ggml: enable perf timings" ${GGML_PERF})
+option(GGML_PERF                    "ggml: enable perf timings"          OFF)
 option(GGML_NO_ACCELERATE           "ggml: disable Accelerate framework" OFF)
 
 # sanitizers
index b913522ed77d51dc81eef4a7dbb7572370b046ae..995eefc18e73a8c44301033cbbf3f5013ef4d242 100644 (file)
@@ -216,7 +216,7 @@ int main(int argc, char ** argv) {
 
         // run the inference
         {
-            whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
+            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
             wparams.print_realtime       = true;
             wparams.print_progress       = false;
index 988527811c3cda0c253fc28908f95db7159f8f0e..236fcf1dba26d37eb81b6a2485f118270b99f36e 100644 (file)
@@ -2256,51 +2256,63 @@ void whisper_print_timings(struct whisper_context * ctx) {
 
 ////////////////////////////////////////////////////////////////////////////
 
-struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy) {
+struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params result;
 
     switch (strategy) {
-        case WHISPER_DECODE_GREEDY:
+        case WHISPER_SAMPLING_GREEDY:
             {
                 result = {
-                    .strategy  = WHISPER_DECODE_GREEDY,
-                    .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    .offset_ms = 0,
+                    /*.strategy             =*/ WHISPER_SAMPLING_GREEDY,
 
-                    .translate            = false,
-                    .no_context           = false,
-                    .print_special_tokens = false,
-                    .print_progress       = true,
-                    .print_realtime       = false,
-                    .print_timestamps     = true,
+                    /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    /*.offset_ms            =*/ 0,
 
-                    .language = "en",
+                    /*.translate            =*/ false,
+                    /*.no_context           =*/ false,
+                    /*.print_special_tokens =*/ false,
+                    /*.print_progress       =*/ true,
+                    /*.print_realtime       =*/ false,
+                    /*.print_timestamps     =*/ true,
 
-                    .greedy = {
-                        .n_past = 0,
+                    /*.language             =*/ "en",
+
+                    /*.greedy               =*/ {
+                        /*.n_past =*/ 0,
+                    },
+
+                    /*.beam_search          =*/ {
+                        /*.n_past     =*/ -1,
+                        /*.beam_width =*/ -1,
+                        /*.n_best     =*/ -1,
                     },
                 };
             } break;
-        case WHISPER_DECODE_BEAM_SEARCH:
+        case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result = {
-                    .strategy  = WHISPER_DECODE_GREEDY,
-                    .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    .offset_ms = 0,
-
-                    .translate            = false,
-                    .no_context           = false,
-                    .print_special_tokens = false,
-                    .print_progress       = true,
-                    .print_realtime       = false,
-                    .print_timestamps     = true,
-
-                    .language = "en",
-
-                    .beam_search = {
-                        .n_past = 0,
-                        .beam_width = 10,
-                        .n_best = 5,
+                    /*.strategy             =*/ WHISPER_SAMPLING_BEAM_SEARCH,
+
+                    /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    /*.offset_ms            =*/ 0,
+
+                    /*.translate            =*/ false,
+                    /*.no_context           =*/ false,
+                    /*.print_special_tokens =*/ false,
+                    /*.print_progress       =*/ true,
+                    /*.print_realtime       =*/ false,
+                    /*.print_timestamps     =*/ true,
+
+                    /*.language             =*/ "en",
+
+                    /*.greedy               =*/ {
+                        /*.n_past =*/ -1,
+                    },
+
+                    /*.beam_search          =*/ {
+                        /*.n_past     =*/ 0,
+                        /*.beam_width =*/ 10,
+                        /*.n_best     =*/ 5,
                     },
                 };
             } break;
@@ -2425,7 +2437,7 @@ int whisper_full(
                 whisper_token id  = 0;
                 whisper_token tid = whisper_token_beg(ctx);
 
-                id = whisper_sample_best(ctx, result_len == 0 || i > 32);
+                id = whisper_sample_best(ctx, result_len == 0);
                 if (i > 0) {
                     tid = whisper_sample_timestamp(ctx);
                 }
@@ -2445,9 +2457,12 @@ int whisper_full(
                 // end of text token
                 if (id == whisper_token_eot(ctx)) {
                     if (result_len == 0) {
-                        // TODO: figure out how to resolve this
-                        fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
-                        //result_len = i + 1;
+                        if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
+                            result_len = i + 1;
+                        } else {
+                            // TODO: figure out how to resolve this
+                            fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
+                        }
                     }
                     break;
                 }
index 381afd71d6e7495d26cf49dc70169cf2bea730e5..45faa5b2f220a2d662ff643f38f4afd53e97f546 100644 (file)
@@ -31,7 +31,8 @@ extern "C" {
     //
     // C interface
     //
-
+    // The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads
+    // concurrently.
     //
     // Basic usage:
     //
@@ -153,14 +154,14 @@ extern "C" {
 
     ////////////////////////////////////////////////////////////////////////////
 
-    // Available decoding strategies
-    enum whisper_decode_strategy {
-        WHISPER_DECODE_GREEDY,      // Always select the most probable token
-        WHISPER_DECODE_BEAM_SEARCH, // TODO: not implemented yet!
+    // Available sampling strategies
+    enum whisper_sampling_strategy {
+        WHISPER_SAMPLING_GREEDY,      // Always select the most probable token
+        WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
     };
 
     struct whisper_full_params {
-        enum whisper_decode_strategy strategy;
+        enum whisper_sampling_strategy strategy;
 
         int n_threads;
         int offset_ms;
@@ -174,20 +175,18 @@ extern "C" {
 
         const char * language;
 
-        union {
-            struct {
-                int n_past;
-            } greedy;
-
-            struct {
-                int n_past;
-                int beam_width;
-                int n_best;
-            } beam_search;
-        };
+        struct {
+            int n_past;
+        } greedy;
+
+        struct {
+            int n_past;
+            int beam_width;
+            int n_best;
+        } beam_search;
     };
 
-    WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy);
+    WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
     // Uses the specified decoding strategy to obtain the text.
index 7f11c96faa58eacd466719de7f22cd58df947d28..4861f24925396e85661b3c0b2b8515a49abb6f74 100644 (file)
@@ -14,7 +14,6 @@
 #include <stdint.h>
 #include <stdio.h>
 
-
 #if defined _MSC_VER
 #include "msvc_thread_atomic.h"
 #else
@@ -24,6 +23,7 @@ typedef void* thread_ret_t;
 #endif
 
 #define GGML_DEBUG 0
+#define GGML_GELU_FP16
 
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
@@ -723,20 +723,22 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
     }
 }
 
+#ifdef GGML_GELU_FP16
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
     for (int i = 0; i < n; ++i) {
         ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]);
         memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = table_gelu_f16[t];
+        y[i] = ggml_fp16_to_fp32(table_gelu_f16[t]);
     }
 }
-
-//inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-//    for (int i = 0; i < n; ++i) {
-//        y[i] = ggml_gelu_f32(x[i]);
-//    }
-//}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+}
+#endif
 
 inline static void ggml_vec_sum_f32     (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
 inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
diff --git a/src/msvc_thread_atomic.h b/src/msvc_thread_atomic.h
new file mode 100644 (file)
index 0000000..52cd419
--- /dev/null
@@ -0,0 +1,31 @@
+#pragma once
+#include <Windows.h>
+
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+
+static void atomic_store(atomic_int* ptr, LONG val) {
+    InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int* ptr) {
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+    return atomic_fetch_add(ptr, -(dec));
+}
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+    out = CreateThread(NULL, 0, func, arg, 0, NULL);
+    return out != NULL;
+}
+
+static int pthread_join(pthread_t thread, void* unused) {
+    return (int) WaitForSingleObject(thread, INFINITE);
+}
+