]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[SYCL] get MAX_MEM_ALLOC from device property (#5270)
authorMeng, Hengyu <redacted>
Fri, 2 Feb 2024 07:54:14 +0000 (15:54 +0800)
committerGitHub <redacted>
Fri, 2 Feb 2024 07:54:14 +0000 (15:54 +0800)
* get max alloc size from device prop

* fix macro typo

ggml-sycl.cpp

index e8ba483538c11f90f010f87bcca9f3aec3de3eb9..4ee2eed387bc03b5dd2a58e251532cd94728d88b 100644 (file)
@@ -337,6 +337,7 @@ namespace dpct
         }
         size_t get_global_mem_size() const { return _global_mem_size; }
         size_t get_local_mem_size() const { return _local_mem_size; }
+        size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
         /// Returns the maximum clock rate of device's global memory in kHz. If
         /// compiler does not support this API then returns default value 3200000 kHz.
         unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
@@ -398,6 +399,10 @@ namespace dpct
         {
             _local_mem_size = local_mem_size;
         }
+        void set_max_mem_alloc_size(size_t max_mem_alloc_size)
+        {
+            _max_mem_alloc_size = max_mem_alloc_size;
+        }
         void set_max_work_group_size(int max_work_group_size)
         {
             _max_work_group_size = max_work_group_size;
@@ -465,6 +470,7 @@ namespace dpct
         int _max_register_size_per_work_group;
         size_t _global_mem_size;
         size_t _local_mem_size;
+        size_t _max_mem_alloc_size;
         size_t _max_nd_range_size[3];
         int _max_nd_range_size_i[3];
         uint32_t _device_id;
@@ -516,6 +522,7 @@ namespace dpct
             dev.get_info<sycl::info::device::max_work_group_size>());
         prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
         prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
+        prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());
 
 #if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
         if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
@@ -644,6 +651,11 @@ namespace dpct
             return get_device_info().get_global_mem_size();
         }
 
+        size_t get_max_mem_alloc_size() const
+        {
+            return get_device_info().get_max_mem_alloc_size();
+        }
+
         /// Get the number of bytes of free and total memory on the SYCL device.
         /// \param [out] free_memory The number of bytes of free memory on the SYCL device.
         /// \param [out] total_memory The number of bytes of total memory on the SYCL device.
@@ -11311,10 +11323,10 @@ void ggml_init_sycl() try {
         GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
         int64_t total_vram = 0;
 
-#if defined(GGML_SYCL_FP16)
-        fprintf(stderr, "%s: GGML_SYCL_FP16:   yes\n", __func__);
+#if defined(GGML_SYCL_F16)
+        fprintf(stderr, "%s: GGML_SYCL_F16:   yes\n", __func__);
 #else
-        fprintf(stderr, "%s: GGML_SYCL_FP16:   no\n", __func__);
+        fprintf(stderr, "%s: GGML_SYCL_F16:   no\n", __func__);
 #endif
 
 
@@ -14788,6 +14800,12 @@ static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_ty
     UNUSED(buft);
 }
 
+static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    return dpct::get_current_device().get_max_mem_alloc_size();
+
+    UNUSED(buft);
+}
+
 static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     int64_t row_low = 0;
     int64_t row_high = ggml_nrows(tensor);
@@ -14818,7 +14836,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
     /* .get_name         = */ ggml_backend_sycl_buffer_type_name,
     /* .alloc_buffer     = */ ggml_backend_sycl_buffer_type_alloc_buffer,
     /* .get_alignment    = */ ggml_backend_sycl_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL, // TODO: return device.maxBufferLength
+    /* .get_max_size     = */ ggml_backend_sycl_buffer_type_get_max_size,
     /* .get_alloc_size   = */ ggml_backend_sycl_buffer_type_get_alloc_size,
     /* .supports_backend = */ ggml_backend_sycl_buffer_type_supports_backend,
     /* .is_host          = */ nullptr,