]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
scripts : support arbitrary input file formats in compare-llama-bench.py (#13455)
authorSigbjørn Skjæret <redacted>
Tue, 13 May 2025 13:31:12 +0000 (15:31 +0200)
committerGitHub <redacted>
Tue, 13 May 2025 13:31:12 +0000 (15:31 +0200)
scripts/compare-llama-bench.py

index c32b449f7dc6b66a2ed4d022ed9f6fd695317154..fc93bf62a67fac877d7652fbbe1c6d7eb185639f 100755 (executable)
@@ -7,6 +7,10 @@ import sys
 import os
 from glob import glob
 import sqlite3
+import json
+import csv
+from typing import Optional, Union
+from collections.abc import Iterator, Sequence
 
 try:
     import git
@@ -17,6 +21,28 @@ except ImportError as e:
 
 logger = logging.getLogger("compare-llama-bench")
 
+# All llama-bench SQL fields
+DB_FIELDS = [
+    "build_commit", "build_number", "cpu_info",       "gpu_info",   "backends",     "model_filename",
+    "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
+    "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
+    "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
+    "defrag_thold",
+    "use_mmap",     "embeddings",   "no_op_offload",  "n_prompt",   "n_gen",        "n_depth",
+    "test_time",    "avg_ns",       "stddev_ns",      "avg_ts",     "stddev_ts",
+]
+
+DB_TYPES = [
+    "TEXT",    "INTEGER", "TEXT",    "TEXT",    "TEXT",    "TEXT",
+    "TEXT",    "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
+    "TEXT",    "INTEGER", "INTEGER", "TEXT",    "TEXT",    "INTEGER",
+    "TEXT",    "INTEGER", "INTEGER", "INTEGER", "TEXT",    "TEXT",
+    "REAL",
+    "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
+    "TEXT",    "INTEGER", "INTEGER", "REAL",    "REAL",
+]
+assert len(DB_FIELDS) == len(DB_TYPES)
+
 # Properties by which to differentiate results per commit:
 KEY_PROPERTIES = [
     "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
@@ -42,7 +68,7 @@ DEFAULT_HIDE = ["model_filename"]  # Always hide these properties by default.
 GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "]  # Strip prefixes for smaller tables.
 MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
 
-DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
+DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
 
 $ git checkout master
 $ make clean && make llama-bench
@@ -70,12 +96,13 @@ help_c = (
 )
 parser.add_argument("-c", "--compare", help=help_c)
 help_i = (
-    "Input SQLite file for comparing commits. "
+    "JSON/JSONL/SQLite/CSV files for comparing commits. "
+    "Specify multiple times to use multiple input files (JSON/CSV only). "
     "Defaults to 'llama-bench.sqlite' in the current working directory. "
     "If no such file is found and there is exactly one .sqlite file in the current directory, "
     "that file is instead used as input."
 )
-parser.add_argument("-i", "--input", help=help_i)
+parser.add_argument("-i", "--input", action="append", help=help_i)
 help_o = (
     "Output format for the table. "
     "Defaults to 'pipe' (GitHub compatible). "
@@ -110,119 +137,321 @@ if unknown_args:
     sys.exit(1)
 
 input_file = known_args.input
-if input_file is None and os.path.exists("./llama-bench.sqlite"):
-    input_file = "llama-bench.sqlite"
-if input_file is None:
+if not input_file and os.path.exists("./llama-bench.sqlite"):
+    input_file = ["llama-bench.sqlite"]
+if not input_file:
     sqlite_files = glob("*.sqlite")
     if len(sqlite_files) == 1:
-        input_file = sqlite_files[0]
+        input_file = sqlite_files
 
-if input_file is None:
+if not input_file:
     logger.error("Cannot find a suitable input file, please provide one.\n")
     parser.print_help()
     sys.exit(1)
 
-connection = sqlite3.connect(input_file)
-cursor = connection.cursor()
 
-build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
-build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
+class LlamaBenchData:
+    repo: Optional[git.Repo]
+    build_len_min: int
+    build_len_max: int
+    build_len: int = 8
+    builds: list[str] = []
+    check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
 
-if build_len_min != build_len_max:
-    logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
-                   "Try purging the the database of old commits.")
-    cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});")
+    def __init__(self):
+        try:
+            self.repo = git.Repo(".", search_parent_directories=True)
+        except git.InvalidGitRepositoryError:
+            self.repo = None
 
-build_len: int = build_len_min
+    def _builds_init(self):
+        self.build_len = self.build_len_min
 
-builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
-builds = list(map(lambda b: b[0], builds))  # list[tuple[str]] -> list[str]
+    def _check_keys(self, keys: set) -> Optional[set]:
+        """Private helper method that checks against required data keys and returns missing ones."""
+        if not keys >= self.check_keys:
+            return self.check_keys - keys
+        return None
 
-if not builds:
-    raise RuntimeError(f"{input_file} does not contain any builds.")
+    def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
+        """Helper method to find the most recent parent measured in number of commits for which there is data."""
+        heap: list[tuple[int, git.Commit]] = [(0, commit)]
+        seen_hexsha8 = set()
+        while heap:
+            depth, current_commit = heapq.heappop(heap)
+            current_hexsha8 = commit.hexsha[:self.build_len]
+            if current_hexsha8 in self.builds:
+                return current_hexsha8
+            for parent in commit.parents:
+                parent_hexsha8 = parent.hexsha[:self.build_len]
+                if parent_hexsha8 not in seen_hexsha8:
+                    seen_hexsha8.add(parent_hexsha8)
+                    heapq.heappush(heap, (depth + 1, parent))
+        return None
 
-try:
-    repo = git.Repo(".", search_parent_directories=True)
-except git.InvalidGitRepositoryError:
-    repo = None
-
-
-def find_parent_in_data(commit: git.Commit):
-    """Helper function to find the most recent parent measured in number of commits for which there is data."""
-    heap: list[tuple[int, git.Commit]] = [(0, commit)]
-    seen_hexsha8 = set()
-    while heap:
-        depth, current_commit = heapq.heappop(heap)
-        current_hexsha8 = commit.hexsha[:build_len]
-        if current_hexsha8 in builds:
-            return current_hexsha8
-        for parent in commit.parents:
-            parent_hexsha8 = parent.hexsha[:build_len]
-            if parent_hexsha8 not in seen_hexsha8:
-                seen_hexsha8.add(parent_hexsha8)
-                heapq.heappush(heap, (depth + 1, parent))
-    return None
-
-
-def get_all_parent_hexsha8s(commit: git.Commit):
-    """Helper function to recursively get hexsha8 values for all parents of a commit."""
-    unvisited = [commit]
-    visited   = []
-
-    while unvisited:
-        current_commit = unvisited.pop(0)
-        visited.append(current_commit.hexsha[:build_len])
-        for parent in current_commit.parents:
-            if parent.hexsha[:build_len] not in visited:
-                unvisited.append(parent)
-
-    return visited
-
-
-def get_commit_name(hexsha8: str):
-    """Helper function to find a human-readable name for a commit if possible."""
-    if repo is None:
+    def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
+        """Helper method to recursively get hexsha8 values for all parents of a commit."""
+        unvisited = [commit]
+        visited   = []
+
+        while unvisited:
+            current_commit = unvisited.pop(0)
+            visited.append(current_commit.hexsha[:self.build_len])
+            for parent in current_commit.parents:
+                if parent.hexsha[:self.build_len] not in visited:
+                    unvisited.append(parent)
+
+        return visited
+
+    def get_commit_name(self, hexsha8: str) -> str:
+        """Helper method to find a human-readable name for a commit if possible."""
+        if self.repo is None:
+            return hexsha8
+        for h in self.repo.heads:
+            if h.commit.hexsha[:self.build_len] == hexsha8:
+                return h.name
+        for t in self.repo.tags:
+            if t.commit.hexsha[:self.build_len] == hexsha8:
+                return t.name
         return hexsha8
-    for h in repo.heads:
-        if h.commit.hexsha[:build_len] == hexsha8:
-            return h.name
-    for t in repo.tags:
-        if t.commit.hexsha[:build_len] == hexsha8:
-            return t.name
-    return hexsha8
-
-
-def get_commit_hexsha8(name: str):
-    """Helper function to search for a commit given a human-readable name."""
-    if repo is None:
+
+    def get_commit_hexsha8(self, name: str) -> Optional[str]:
+        """Helper method to search for a commit given a human-readable name."""
+        if self.repo is None:
+            return None
+        for h in self.repo.heads:
+            if h.name == name:
+                return h.commit.hexsha[:self.build_len]
+        for t in self.repo.tags:
+            if t.name == name:
+                return t.commit.hexsha[:self.build_len]
+        for c in self.repo.iter_commits("--all"):
+            if c.hexsha[:self.build_len] == name[:self.build_len]:
+                return c.hexsha[:self.build_len]
         return None
-    for h in repo.heads:
-        if h.name == name:
-            return h.commit.hexsha[:build_len]
-    for t in repo.tags:
-        if t.name == name:
-            return t.commit.hexsha[:build_len]
-    for c in repo.iter_commits("--all"):
-        if c.hexsha[:build_len] == name[:build_len]:
-            return c.hexsha[:build_len]
-    return None
+
+    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
+        """Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
+        return []
+
+    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
+        """
+        Helper method that gets table rows for some list of properties.
+        Rows are created by combining those where all provided properties are equal.
+        The resulting rows are then grouped by the provided properties and the t/s values are averaged.
+        The returned rows are unique in terms of property combinations.
+        """
+        return []
+
+
+class LlamaBenchDataSQLite3(LlamaBenchData):
+    connection: sqlite3.Connection
+    cursor: sqlite3.Cursor
+
+    def __init__(self):
+        super().__init__()
+        self.connection = sqlite3.connect(":memory:")
+        self.cursor = self.connection.cursor()
+        self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});")
+
+    def _builds_init(self):
+        if self.connection:
+            self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
+            self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
+
+            if self.build_len_min != self.build_len_max:
+                logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
+                               "Try purging the the database of old commits.")
+                self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
+
+            builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
+            self.builds = list(map(lambda b: b[0], builds))  # list[tuple[str]] -> list[str]
+        super()._builds_init()
+
+    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
+        data = self.cursor.execute(
+            "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
+        return reversed(data) if reverse else data
+
+    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
+        select_string = ", ".join(
+            [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
+        equal_string = " AND ".join(
+            [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
+                f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
+        )
+        group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
+        query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
+                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
+        return self.cursor.execute(query).fetchall()
+
+
+class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
+    def __init__(self, data_file: str):
+        super().__init__()
+
+        self.connection.close()
+        self.connection = sqlite3.connect(data_file)
+        self.cursor = self.connection.cursor()
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_file: str) -> bool:
+        connection = sqlite3.connect(data_file)
+        cursor = connection.cursor()
+
+        try:
+            if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
+                raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
+        except sqlite3.DatabaseError as e:
+            logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
+            cursor = None
+
+        connection.close()
+        return True if cursor else False
+
+
+class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
+    def __init__(self, data_file: str):
+        super().__init__()
+
+        with open(data_file, "r", encoding="utf-8") as fp:
+            for i, line in enumerate(fp):
+                parsed = json.loads(line)
+
+                for k in parsed.keys() - set(DB_FIELDS):
+                    del parsed[k]
+
+                if (missing_keys := self._check_keys(parsed.keys())):
+                    raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
+
+                self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
+
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_file: str) -> bool:
+        try:
+            with open(data_file, "r", encoding="utf-8") as fp:
+                for line in fp:
+                    json.loads(line)
+                    break
+        except Exception as e:
+            logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
+            return False
+
+        return True
+
+
+class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
+    def __init__(self, data_files: list[str]):
+        super().__init__()
+
+        for data_file in data_files:
+            with open(data_file, "r", encoding="utf-8") as fp:
+                parsed = json.load(fp)
+
+                for i, entry in enumerate(parsed):
+                    for k in entry.keys() - set(DB_FIELDS):
+                        del entry[k]
+
+                    if (missing_keys := self._check_keys(entry.keys())):
+                        raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
+
+                    self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
+
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_files: list[str]) -> bool:
+        if not data_files:
+            return False
+
+        for data_file in data_files:
+            try:
+                with open(data_file, "r", encoding="utf-8") as fp:
+                    json.load(fp)
+            except Exception as e:
+                logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
+                return False
+
+        return True
+
+
+class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
+    def __init__(self, data_files: list[str]):
+        super().__init__()
+
+        for data_file in data_files:
+            with open(data_file, "r", encoding="utf-8") as fp:
+                for i, parsed in enumerate(csv.DictReader(fp)):
+                    keys = set(parsed.keys())
+
+                    for k in keys - set(DB_FIELDS):
+                        del parsed[k]
+
+                    if (missing_keys := self._check_keys(keys)):
+                        raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
+
+                    self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
+
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_files: list[str]) -> bool:
+        if not data_files:
+            return False
+
+        for data_file in data_files:
+            try:
+                with open(data_file, "r", encoding="utf-8") as fp:
+                    for parsed in csv.DictReader(fp):
+                        break
+            except Exception as e:
+                logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
+                return False
+
+        return True
+
+
+bench_data = None
+if len(input_file) == 1:
+    if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
+        bench_data = LlamaBenchDataSQLite3File(input_file[0])
+    elif LlamaBenchDataJSON.valid_format(input_file):
+        bench_data = LlamaBenchDataJSON(input_file)
+    elif LlamaBenchDataJSONL.valid_format(input_file[0]):
+        bench_data = LlamaBenchDataJSONL(input_file[0])
+    elif LlamaBenchDataCSV.valid_format(input_file):
+        bench_data = LlamaBenchDataCSV(input_file)
+else:
+    if LlamaBenchDataJSON.valid_format(input_file):
+        bench_data = LlamaBenchDataJSON(input_file)
+    elif LlamaBenchDataCSV.valid_format(input_file):
+        bench_data = LlamaBenchDataCSV(input_file)
+
+if not bench_data:
+    raise RuntimeError("No valid (or some invalid) input files found.")
+
+if not bench_data.builds:
+    raise RuntimeError(f"{input_file} does not contain any builds.")
 
 
 hexsha8_baseline = name_baseline = None
 
 # If the user specified a baseline, try to find a commit for it:
 if known_args.baseline is not None:
-    if known_args.baseline in builds:
+    if known_args.baseline in bench_data.builds:
         hexsha8_baseline = known_args.baseline
     if hexsha8_baseline is None:
-        hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
+        hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
         name_baseline = known_args.baseline
     if hexsha8_baseline is None:
         logger.error(f"cannot find data for baseline={known_args.baseline}.")
         sys.exit(1)
 # Otherwise, search for the most recent parent of master for which there is data:
-elif repo is not None:
-    hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
+elif bench_data.repo is not None:
+    hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)
 
     if hexsha8_baseline is None:
         logger.error("No baseline was provided and did not find data for any master branch commits.\n")
@@ -235,27 +464,25 @@ else:
     sys.exit(1)
 
 
-name_baseline = get_commit_name(hexsha8_baseline)
+name_baseline = bench_data.get_commit_name(hexsha8_baseline)
 
 hexsha8_compare = name_compare = None
 
 # If the user has specified a compare value, try to find a corresponding commit:
 if known_args.compare is not None:
-    if known_args.compare in builds:
+    if known_args.compare in bench_data.builds:
         hexsha8_compare = known_args.compare
     if hexsha8_compare is None:
-        hexsha8_compare = get_commit_hexsha8(known_args.compare)
+        hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
         name_compare = known_args.compare
     if hexsha8_compare is None:
         logger.error(f"cannot find data for compare={known_args.compare}.")
         sys.exit(1)
 # Otherwise, search for the commit for llama-bench was most recently run
 # and that is not a parent of master:
-elif repo is not None:
-    hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit)
-    builds_timestamp = cursor.execute(
-        "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
-    for (hexsha8, _) in reversed(builds_timestamp):
+elif bench_data.repo is not None:
+    hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
+    for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
         if hexsha8 not in hexsha8s_master:
             hexsha8_compare = hexsha8
             break
@@ -270,26 +497,7 @@ else:
     parser.print_help()
     sys.exit(1)
 
-name_compare = get_commit_name(hexsha8_compare)
-
-
-def get_rows(properties):
-    """
-    Helper function that gets table rows for some list of properties.
-    Rows are created by combining those where all provided properties are equal.
-    The resulting rows are then grouped by the provided properties and the t/s values are averaged.
-    The returned rows are unique in terms of property combinations.
-    """
-    select_string = ", ".join(
-        [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
-    equal_string = " AND ".join(
-        [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
-            f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
-    )
-    group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
-    query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
-             f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
-    return cursor.execute(query).fetchall()
+name_compare = bench_data.get_commit_name(hexsha8_compare)
 
 
 # If the user provided columns to group the results by, use them:
@@ -303,10 +511,10 @@ if known_args.show is not None:
         logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
         parser.print_usage()
         sys.exit(1)
-    rows_show = get_rows(show)
+    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 # Otherwise, select those columns where the values are not all the same:
 else:
-    rows_full = get_rows(KEY_PROPERTIES)
+    rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare)
     properties_different = []
     for i, kp_i in enumerate(KEY_PROPERTIES):
         if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]:
@@ -336,7 +544,7 @@ else:
             show.remove(prop)
         except ValueError:
             pass
-    rows_show = get_rows(show)
+    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 
 if not rows_show:
     logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")