extra: Add benchmark script implemented in Python (#1298)
author Neil Chudleigh <redacted>
Mon, 25 Sep 2023 15:45:15 +0000 (08:45 -0700)
committer GitHub <redacted>
Mon, 25 Sep 2023 15:45:15 +0000 (23:45 +0800)
* Create bench.py

* Various benchmark results

* Update benchmark script with hardware name and file checks

* Remove old benchmark results

* Add git shorthash

* Round to 2 digits on calculated floats

* Fix the header reference when sorting results

* Fix order of models

* Parse file name

* Simplify filecheck

* Improve run print statement

* Use simplified model name

* Update benchmark_results.csv

* Process single or lists of processors and threads

* Ignore benchmark results, don't check in

* Move bench.py to extra folder

* Readme section on how to use

* Move command to correct location

* Use separate list for models that exist

* Handle subprocess error in git short hash check

* Fix filtered models list initialization

.gitignore
README.md
extra/bench.py [new file with mode: 0644]

index b30a1d19f01dc6d46c7b82f06aee53c0e9941270..ab1fc2e3a7737449cfd00b12d81e5980aea313ae 100644 (file)
@@ -46,3 +46,5 @@ models/*.mlpackage
 bindings/java/.gradle/
 bindings/java/.idea/
 .idea/
+
+benchmark_results.csv
index 894e0e0338ecadb70f468a468c1434f9a42d2718..5831797f79800668ce1d269e4be1bbfea0f2bde9 100644 (file)
--- a/README.md
+++ b/README.md
@@ -709,6 +709,19 @@ took to execute it. The results are summarized in the following Github issue:
 
 [Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
 
+Additionally, a script to run whisper.cpp with different models and audio files is provided: [bench.py](extra/bench.py).
+
+You can run it with the following command; by default it will run against any standard model in the models folder.
+
+```bash
+python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2
+```
+
+It is written in Python with the intention of being easy to modify and extend for your benchmarking use case.
+
+It outputs a CSV file with the results of the benchmark.
+
+
 ## ggml format
 
 The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
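The column names written to `benchmark_results.csv` are defined in `extra/bench.py` below. As a minimal sketch (not part of this commit), the results file could be read back with Python's `csv` module, assuming the default output location in the repository root:

```python
import csv

# Read benchmark_results.csv produced by extra/bench.py and print the
# total time for each model / thread / processor-count combination.
with open("benchmark_results.csv", newline="") as csvfile:
    for row in csv.DictReader(csvfile):
        print(
            f"{row['Model']} "
            f"(threads={row['Thread']}, processors={row['Processor Count']}): "
            f"{row['Total Time (ms)']} ms total"
        )
```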
diff --git a/extra/bench.py b/extra/bench.py
new file mode 100644 (file)
index 0000000..74956e7
--- /dev/null
@@ -0,0 +1,222 @@
+import os
+import subprocess
+import re
+import csv
+import wave
+import contextlib
+import argparse
+
+
+# Custom action to handle comma-separated list
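+# e.g. "-t 2,4,8" is parsed into the list [2, 4, 8]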
+class ListAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, [int(val) for val in values.split(",")])
+
+
+parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")
+
+# Define the argument to accept a list
+parser.add_argument(
+    "-t",
+    "--threads",
+    dest="threads",
+    action=ListAction,
+    default=[4],
+    help="List of thread counts to benchmark (comma-separated, default: 4)",
+)
+
+parser.add_argument(
+    "-p",
+    "--processors",
+    dest="processors",
+    action=ListAction,
+    default=[1],
+    help="List of processor counts to benchmark (comma-separated, default: 1)",
+)
+
+
+parser.add_argument(
+    "-f",
+    "--filename",
+    type=str,
+    default="./samples/jfk.wav",
+    help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
+)
+
+# Parse the command line arguments
+args = parser.parse_args()
+
+sample_file = args.filename
+
+threads = args.threads
+processors = args.processors
+
+# Define the models, threads, and processor counts to benchmark
+models = [
+    "ggml-tiny.en.bin",
+    "ggml-tiny.bin",
+    "ggml-base.en.bin",
+    "ggml-base.bin",
+    "ggml-small.en.bin",
+    "ggml-small.bin",
+    "ggml-medium.en.bin",
+    "ggml-medium.bin",
+    "ggml-large.bin",
+]
+
+
+metal_device = ""
+
+# Initialize a dictionary to hold the results
+results = {}
+
+gitHashHeader = "Commit"
+modelHeader = "Model"
+hardwareHeader = "Hardware"
+recordingLengthHeader = "Recording Length (seconds)"
+threadHeader = "Thread"
+processorCountHeader = "Processor Count"
+loadTimeHeader = "Load Time (ms)"
+sampleTimeHeader = "Sample Time (ms)"
+encodeTimeHeader = "Encode Time (ms)"
+decodeTimeHeader = "Decode Time (ms)"
+sampleTimePerRunHeader = "Sample Time per Run (ms)"
+encodeTimePerRunHeader = "Encode Time per Run (ms)"
+decodeTimePerRunHeader = "Decode Time per Run (ms)"
+totalTimeHeader = "Total Time (ms)"
+
+
+def check_file_exists(file: str) -> bool:
+    return os.path.isfile(file)
+
+
+def get_git_short_hash() -> str:
+    try:
+        return (
+            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+            .decode()
+            .strip()
+        )
+    except subprocess.CalledProcessError:
+        return ""
+
+
+def wav_file_length(file: str = sample_file) -> float:
+    with contextlib.closing(wave.open(file, "r")) as f:
+        frames = f.getnframes()
+        rate = f.getframerate()
+        duration = frames / float(rate)
+        return duration
+
+
+def extract_metrics(output: str, label: str) -> tuple[float, float]:
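+    # Matches timing lines of the form "<label> = <float> ms / <int> runs";
+    # returns (None, None) when no such line is present in the output.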
+    match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output)
+    time = float(match.group(1)) if match else None
+    runs = float(match.group(2)) if match else None
+    return time, runs
+
+
+def extract_device(output: str) -> str:
+    match = re.search(r"picking default device: (.*)", output)
+    device = match.group(1) if match else "Not found"
+    return device
+
+
+# Check if the sample file exists
+if not check_file_exists(sample_file):
+    raise FileNotFoundError(f"Sample file {sample_file} not found")
+
+recording_length = wav_file_length()
+
+
+# Check that all models exist
+# Filter out models from list that are not downloaded
+filtered_models = []
+for model in models:
+    if check_file_exists(f"models/{model}"):
+        filtered_models.append(model)
+    else:
+        print(f"Model {model} not found, removing from list")
+
+models = filtered_models
+
+# Loop over each combination of parameters
+for model in filtered_models:
+    for thread in threads:
+        for processor_count in processors:
+            # Construct the command to run
+            cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}"
+            # Run the command and get the output
+            process = subprocess.Popen(
+                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+            )
+
+            output = ""
+            while process.poll() is None:
+                output += process.stdout.read().decode()
+
+            # Parse the output
+            load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
+            load_time = float(load_time_match.group(1)) if load_time_match else None
+
+            metal_device = extract_device(output)
+            sample_time, sample_runs = extract_metrics(output, "sample time")
+            encode_time, encode_runs = extract_metrics(output, "encode time")
+            decode_time, decode_runs = extract_metrics(output, "decode time")
+
+            total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
+            total_time = float(total_time_match.group(1)) if total_time_match else None
+
+            model_name = model.replace("ggml-", "").replace(".bin", "")
+
+            print(
+                f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
+            )
+            # Store the times in the results dictionary
+            results[(model_name, thread, processor_count)] = {
+                loadTimeHeader: load_time,
+                sampleTimeHeader: sample_time,
+                encodeTimeHeader: encode_time,
+                decodeTimeHeader: decode_time,
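+                # Per-run averages: cumulative stage time divided by the reported run count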
+                sampleTimePerRunHeader: round(sample_time / sample_runs, 2),
+                encodeTimePerRunHeader: round(encode_time / encode_runs, 2),
+                decodeTimePerRunHeader: round(decode_time / decode_runs, 2),
+                totalTimeHeader: total_time,
+            }
+
+# Write the results to a CSV file
+with open("benchmark_results.csv", "w", newline="") as csvfile:
+    fieldnames = [
+        gitHashHeader,
+        modelHeader,
+        hardwareHeader,
+        recordingLengthHeader,
+        threadHeader,
+        processorCountHeader,
+        loadTimeHeader,
+        sampleTimeHeader,
+        encodeTimeHeader,
+        decodeTimeHeader,
+        sampleTimePerRunHeader,
+        encodeTimePerRunHeader,
+        decodeTimePerRunHeader,
+        totalTimeHeader,
+    ]
+    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+    writer.writeheader()
+
+    shortHash = get_git_short_hash()
+    # Sort the results by total time in ascending order
+    sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0))
+    for params, times in sorted_results:
+        row = {
+            gitHashHeader: shortHash,
+            modelHeader: params[0],
+            hardwareHeader: metal_device,
+            recordingLengthHeader: recording_length,
+            threadHeader: params[1],
+            processorCountHeader: params[2],
+        }
+        row.update(times)
+        writer.writerow(row)