]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
scripts: fix crash when --tool is not set (#15133)
authorJohannes Gäßler <redacted>
Thu, 7 Aug 2025 06:50:30 +0000 (08:50 +0200)
committerGitHub <redacted>
Thu, 7 Aug 2025 06:50:30 +0000 (08:50 +0200)
scripts/compare-llama-bench.py

index c974d83b578288bace3918f5baffb4dcf5912929..8366f89a08076683f8a518c7a660d19a399cb5fa 100755 (executable)
@@ -315,28 +315,29 @@ class LlamaBenchData:
 
 
 class LlamaBenchDataSQLite3(LlamaBenchData):
-    connection: sqlite3.Connection
+    connection: Optional[sqlite3.Connection] = None
     cursor: sqlite3.Cursor
     table_name: str
 
     def __init__(self, tool: str = "llama-bench"):
         super().__init__(tool)
-        self.connection = sqlite3.connect(":memory:")
-        self.cursor = self.connection.cursor()
+        if self.connection is None:
+            self.connection = sqlite3.connect(":memory:")
+            self.cursor = self.connection.cursor()
 
-        # Set table name and schema based on tool
-        if self.tool == "llama-bench":
-            self.table_name = "llama_bench"
-            db_fields = LLAMA_BENCH_DB_FIELDS
-            db_types = LLAMA_BENCH_DB_TYPES
-        elif self.tool == "test-backend-ops":
-            self.table_name = "test_backend_ops"
-            db_fields = TEST_BACKEND_OPS_DB_FIELDS
-            db_types = TEST_BACKEND_OPS_DB_TYPES
-        else:
-            assert False
+            # Set table name and schema based on tool
+            if self.tool == "llama-bench":
+                self.table_name = "llama_bench"
+                db_fields = LLAMA_BENCH_DB_FIELDS
+                db_types = LLAMA_BENCH_DB_TYPES
+            elif self.tool == "test-backend-ops":
+                self.table_name = "test_backend_ops"
+                db_fields = TEST_BACKEND_OPS_DB_FIELDS
+                db_types = TEST_BACKEND_OPS_DB_TYPES
+            else:
+                assert False
 
-        self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")
+            self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")
 
     def _builds_init(self):
         if self.connection:
@@ -397,9 +398,6 @@ class LlamaBenchDataSQLite3(LlamaBenchData):
 
 class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
     def __init__(self, data_file: str, tool: Any):
-        super().__init__(tool)
-
-        self.connection.close()
         self.connection = sqlite3.connect(data_file)
         self.cursor = self.connection.cursor()
 
@@ -411,27 +409,28 @@ class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
         if tool is None:
             if "llama_bench" in table_names:
                 self.table_name = "llama_bench"
-                self.tool = "llama-bench"
+                tool = "llama-bench"
             elif "test_backend_ops" in table_names:
                 self.table_name = "test_backend_ops"
-                self.tool = "test-backend-ops"
+                tool = "test-backend-ops"
             else:
                 raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}")
         elif tool == "llama-bench":
             if "llama_bench" in table_names:
                 self.table_name = "llama_bench"
-                self.tool = "llama-bench"
+                tool = "llama-bench"
             else:
                 raise RuntimeError(f"Table 'test' not found for tool 'llama-bench'. Available tables: {table_names}")
         elif tool == "test-backend-ops":
             if "test_backend_ops" in table_names:
                 self.table_name = "test_backend_ops"
-                self.tool = "test-backend-ops"
+                tool = "test-backend-ops"
             else:
                 raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}")
         else:
             raise RuntimeError(f"Unknown tool: {tool}")
 
+        super().__init__(tool)
         self._builds_init()
 
     @staticmethod
@@ -653,6 +652,8 @@ if not bench_data:
 if not bench_data.builds:
     raise RuntimeError(f"{input_file} does not contain any builds.")
 
+tool = bench_data.tool  # May have chosen a default if tool was None.
+
 
 hexsha8_baseline = name_baseline = None