--- /dev/null
+import os
+import subprocess
+import re
+import csv
+import wave
+import contextlib
+import argparse
+
+
+# Custom action to handle comma-separated list
class ListAction(argparse.Action):
    """argparse action that turns a comma-separated string into a list of ints.

    ``-t 1,2,4`` stores ``[1, 2, 4]`` on the namespace under ``self.dest``.
    An item that is not an integer is reported through argparse's normal
    usage-error path instead of surfacing as an uncaught ``ValueError``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        try:
            parsed = [int(item) for item in values.split(",")]
        except ValueError:
            # ArgumentError is caught by parse_args() and rendered as a
            # regular "usage: ... error: argument -t/--threads: ..." message.
            raise argparse.ArgumentError(
                self,
                f"expected a comma-separated list of integers, got {values!r}",
            ) from None
        setattr(namespace, self.dest, parsed)
+
+
# Command-line interface: which thread counts, processor counts and audio
# sample to benchmark.
parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")

# Thread counts: a comma-separated list of integers, parsed by ListAction.
parser.add_argument(
    "-t",
    "--threads",
    dest="threads",
    action=ListAction,
    default=[4],
    help="List of thread counts to benchmark (comma-separated, default: 4)",
)

# Processor counts: same comma-separated format as --threads.
parser.add_argument(
    "-p",
    "--processors",
    dest="processors",
    action=ListAction,
    default=[1],
    help="List of processor counts to benchmark (comma-separated, default: 1)",
)

# Audio file that every benchmark run transcribes.
parser.add_argument(
    "-f",
    "--filename",
    type=str,
    default="./samples/jfk.wav",
    help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
)

# Parse the command line arguments
args = parser.parse_args()

sample_file = args.filename

threads = args.threads
processors = args.processors
+
# Model files to benchmark; each is looked up under the models/ directory.
models = [
    "ggml-tiny.en.bin",
    "ggml-tiny.bin",
    "ggml-base.en.bin",
    "ggml-base.bin",
    "ggml-small.en.bin",
    "ggml-small.bin",
    "ggml-medium.en.bin",
    "ggml-medium.bin",
    "ggml-large.bin",
]


# Device name parsed from the benchmark output; filled in by the benchmark loop.
metal_device = ""

# Results keyed by (model name, thread count, processor count).
results = {}

# CSV column headers. They double as the keys of each per-run result dict.
gitHashHeader = "Commit"
modelHeader = "Model"
hardwareHeader = "Hardware"
recordingLengthHeader = "Recording Length (seconds)"
threadHeader = "Thread"
processorCountHeader = "Processor Count"
loadTimeHeader = "Load Time (ms)"
sampleTimeHeader = "Sample Time (ms)"
encodeTimeHeader = "Encode Time (ms)"
decodeTimeHeader = "Decode Time (ms)"
sampleTimePerRunHeader = "Sample Time per Run (ms)"
encodeTimePerRunHeader = "Encode Time per Run (ms)"
decodeTimePerRunHeader = "Decode Time per Run (ms)"
totalTimeHeader = "Total Time (ms)"
+
+
def check_file_exists(file: str) -> bool:
    """Return True if *file* is the path of an existing regular file.

    Missing paths and directories both yield False.
    """
    return os.path.isfile(file)
+
+
def get_git_short_hash() -> str:
    """Return the abbreviated hash of the current git HEAD.

    This is best-effort metadata for the results file: an empty string is
    returned when git exits with an error (for example outside a repository)
    or when the git executable cannot be launched at all.
    """
    try:
        raw = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed.
        # OSError (incl. FileNotFoundError): git could not be started.
        return ""
    return raw.decode().strip()
+
+
def wav_file_length(file: "str | None" = None) -> float:
    """Return the duration of a WAV file in seconds.

    Args:
        file: Path of the WAV file. When omitted, the module-level
            ``sample_file`` is used; it is looked up at call time, so this
            function no longer requires ``sample_file`` to exist when the
            function is defined.

    Returns:
        Number of frames divided by the frame rate.
    """
    if file is None:
        file = sample_file
    # wave.open() objects are context managers, so the file is closed on exit.
    with wave.open(file, "rb") as wav:
        return wav.getnframes() / float(wav.getframerate())
+
+
def extract_metrics(output: str, label: str) -> "tuple[float | None, float | None]":
    """Extract a ``<label> = <time> ms / <n> runs`` timing from *output*.

    Args:
        output: Text to search.
        label: Literal label preceding the ``=`` sign, e.g. ``"sample time"``.
            It is escaped, so regex metacharacters in it match themselves.

    Returns:
        ``(time_ms, runs)`` as floats for the first match, or ``(None, None)``
        when no such line is present.
    """
    # The time may be printed with or without a fractional part.
    pattern = rf"{re.escape(label)}\s*=\s*(\d+(?:\.\d+)?)\s*ms\s*/\s*(\d+)\s*runs"
    match = re.search(pattern, output)
    if match is None:
        return None, None
    return float(match.group(1)), float(match.group(2))
+
+
def extract_device(output: str) -> str:
    """Return the device name reported after ``picking default device:``.

    The rest of the matched line is returned with surrounding whitespace
    (including a trailing carriage return) removed. ``"Not found"`` is
    returned when *output* contains no such line.
    """
    match = re.search(r"picking default device: (.*)", output)
    if match is None:
        return "Not found"
    return match.group(1).strip()
+
+
# Fail early when the audio sample is missing: every run would fail otherwise.
if not check_file_exists(sample_file):
    raise FileNotFoundError(f"Sample file {sample_file} not found")

# Duration of the sample in seconds; written to every CSV row.
recording_length = wav_file_length(sample_file)


# Keep only the models that are present under models/ and report the rest.
filtered_models = []
for model in models:
    if check_file_exists(f"models/{model}"):
        filtered_models.append(model)
    else:
        print(f"Model {model} not found, removing from list")

models = filtered_models
+
def _per_run(total_ms, runs):
    """Return total_ms / runs rounded to 2 decimals, or None if either is missing or runs is 0."""
    if total_ms is None or not runs:
        return None
    return round(total_ms / runs, 2)


# Loop over each combination of parameters
for model in filtered_models:
    for thread in threads:
        for processor_count in processors:
            # Build an argument list and run without a shell, so paths that
            # contain spaces or shell metacharacters are passed verbatim.
            cmd = [
                "./main",
                "-m",
                f"models/{model}",
                "-t",
                str(thread),
                "-p",
                str(processor_count),
                "-f",
                sample_file,
            ]
            # run() waits for exit and collects everything the child wrote,
            # even when the child finishes immediately. stderr is merged into
            # stdout so all of the binary's output is parsed as one text.
            completed = subprocess.run(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            )
            output = completed.stdout.decode(errors="replace")

            # Parse the output
            load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
            load_time = float(load_time_match.group(1)) if load_time_match else None

            # Holds the value from the most recent run; used for every CSV row.
            metal_device = extract_device(output)
            sample_time, sample_runs = extract_metrics(output, "sample time")
            encode_time, encode_runs = extract_metrics(output, "encode time")
            decode_time, decode_runs = extract_metrics(output, "decode time")

            total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
            total_time = float(total_time_match.group(1)) if total_time_match else None

            model_name = model.replace("ggml-", "").replace(".bin", "")

            print(
                f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
            )
            # Store the times in the results dictionary. Metrics that could
            # not be parsed stay None (written as empty CSV cells) instead of
            # aborting the whole benchmark with a TypeError.
            results[(model_name, thread, processor_count)] = {
                loadTimeHeader: load_time,
                sampleTimeHeader: sample_time,
                encodeTimeHeader: encode_time,
                decodeTimeHeader: decode_time,
                sampleTimePerRunHeader: _per_run(sample_time, sample_runs),
                encodeTimePerRunHeader: _per_run(encode_time, encode_runs),
                decodeTimePerRunHeader: _per_run(decode_time, decode_runs),
                totalTimeHeader: total_time,
            }
+
# Write the results to a CSV file
with open("benchmark_results.csv", "w", newline="") as csvfile:
    fieldnames = [
        gitHashHeader,
        modelHeader,
        hardwareHeader,
        recordingLengthHeader,
        threadHeader,
        processorCountHeader,
        loadTimeHeader,
        sampleTimeHeader,
        encodeTimeHeader,
        decodeTimeHeader,
        sampleTimePerRunHeader,
        encodeTimePerRunHeader,
        decodeTimePerRunHeader,
        totalTimeHeader,
    ]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    writer.writeheader()

    shortHash = get_git_short_hash()

    def _total_time_key(item):
        """Sort key: total time of a (params, times) pair; unparsed runs sort last."""
        total = item[1].get(totalTimeHeader)
        # A stored None cannot be compared with floats, so map it to +inf.
        return float("inf") if total is None else total

    # Sort the results by total time in ascending order
    sorted_results = sorted(results.items(), key=_total_time_key)
    for params, times in sorted_results:
        # params is the (model name, thread count, processor count) key.
        row = {
            gitHashHeader: shortHash,
            modelHeader: params[0],
            hardwareHeader: metal_device,
            recordingLengthHeader: recording_length,
            threadHeader: params[1],
            processorCountHeader: params[2],
        }
        row.update(times)
        writer.writerow(row)