]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Refactor validation and enumeration platform checks into functions to clean up ggml_v...
author0cc4m <redacted>
Wed, 14 Feb 2024 19:57:17 +0000 (20:57 +0100)
committerGeorgi Gerganov <redacted>
Thu, 22 Feb 2024 13:12:36 +0000 (15:12 +0200)
ggml-vulkan.cpp

index 37123ac8f0c4dbc5d5ded37fd8ec3b644e61016d..4e5eaff15110bfcaa7ea5be4eeb4115e9a743eec 100644 (file)
@@ -1091,7 +1091,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     }
 }
 
-static void ggml_vk_instance_init() {
+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+
+void ggml_vk_instance_init() {
     if (vk_instance_initialized) {
         return;
     }
@@ -1102,54 +1105,40 @@ static void ggml_vk_instance_init() {
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
-#ifdef __APPLE__
-    bool portability_enumeration_ext = false;
-    // Check for portability enumeration extension for MoltenVK support
-    for (const auto& properties : instance_extensions) {
-        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
-            portability_enumeration_ext = true;
-            break;
-        }
+    const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
+    const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
+
+    std::vector<const char*> layers;
+
+    if (validation_ext) {
+        layers.push_back("VK_LAYER_KHRONOS_validation");
     }
-    if (!portability_enumeration_ext) {
-        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    std::vector<const char*> extensions;
+    if (validation_ext) {
+        extensions.push_back("VK_EXT_validation_features");
     }
-#endif
-
-    std::vector<const char*> layers = {
-#ifdef GGML_VULKAN_VALIDATE
-        "VK_LAYER_KHRONOS_validation",
-#endif
-    };
-    std::vector<const char*> extensions = {
-#ifdef GGML_VULKAN_VALIDATE
-        "VK_EXT_validation_features",
-#endif
-    };
-#ifdef __APPLE__
     if (portability_enumeration_ext) {
         extensions.push_back("VK_KHR_portability_enumeration");
     }
-#endif
     vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
-#ifdef __APPLE__
     if (portability_enumeration_ext) {
         instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
     }
-#endif
 
+    std::vector<vk::ValidationFeatureEnableEXT> features_enable;
+    vk::ValidationFeaturesEXT validation_features;
 
-#ifdef GGML_VULKAN_VALIDATE
-    const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
-    vk::ValidationFeaturesEXT validation_features = {
-        features_enable,
-        {},
-    };
-    validation_features.setPNext(nullptr);
-    instance_create_info.setPNext(&validation_features);
+    if (validation_ext) {
+        features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
+        validation_features = {
+            features_enable,
+            {},
+        };
+        validation_features.setPNext(nullptr);
+        instance_create_info.setPNext(&validation_features);
 
-    std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
-#endif
+        std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
+    }
     vk_instance.instance = vk::createInstance(instance_create_info);
 
     memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
@@ -5329,6 +5318,42 @@ GGML_CALL int ggml_backend_vk_reg_devices() {
     return vk_instance.device_indices.size();
 }
 
+// Extension availability
+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+#ifdef GGML_VULKAN_VALIDATE
+    bool portability_enumeration_ext = false;
+    // Check for portability enumeration extension for MoltenVK support
+    for (const auto& properties : instance_extensions) {
+        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+            return true;
+        }
+    }
+    if (!portability_enumeration_ext) {
+        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    }
+#endif
+    return false;
+
+    UNUSED(instance_extensions);
+}
+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+#ifdef __APPLE__
+    bool portability_enumeration_ext = false;
+    // Check for portability enumeration extension for MoltenVK support
+    for (const auto& properties : instance_extensions) {
+        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+            return true;
+        }
+    }
+    if (!portability_enumeration_ext) {
+        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    }
+#endif
+    return false;
+
+    UNUSED(instance_extensions);
+}
+
 // checks
 
 #ifdef GGML_VULKAN_CHECK_RESULTS