]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : improve cuda pool efficiency using virtual memory (#4606)
authorslaren <redacted>
Sun, 24 Dec 2023 13:34:22 +0000 (14:34 +0100)
committerGitHub <redacted>
Sun, 24 Dec 2023 13:34:22 +0000 (14:34 +0100)
* cuda : improve cuda pool efficiency using virtual memory

* fix mixtral

* fix cmake build

* check for vmm support, disable for hip

ggml-ci

* fix hip build

* clarify granularity

* move all caps to g_device_caps

* refactor error checking

* add cuda_pool_alloc, refactor most pool allocations

ggml-ci

* fix hip build

* CUBLAS_TF32_TENSOR_OP_MATH is not a macro

* more hip crap

* llama : fix msvc warnings

* ggml : fix msvc warnings

* minor

* minor

* cuda : fallback to CPU on host buffer alloc fail

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <redacted>
* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <redacted>
* ensure allocations are always aligned

* act_size -> actual_size

---------

Co-authored-by: Johannes Gäßler <redacted>
CMakeLists.txt
Makefile
ggml-backend.c
ggml-cuda.cu
ggml.c
ggml.h
llama.cpp
tests/test-grad0.cpp

index 6fc6508c598ff1bf64885c63c4af105ce85a1396..545aab267dbec8e19b8a53d321dc900615ea2542 100644 (file)
@@ -302,6 +302,8 @@ if (LLAMA_CUBLAS)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
         endif()
 
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
+
     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
         # 52 == lowest CUDA 12 standard
         # 60 == f16 CUDA intrinsics
index cb5a4e948e5e39186a359d31d41d7436873d631a..28c6d79bcd7d507007ad7e1ee226430d93a1a0df 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -367,17 +367,15 @@ endif # LLAMA_BLIS
 
 ifdef LLAMA_CUBLAS
        MK_CPPFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
-       MK_LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
+       MK_LDFLAGS   += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
        OBJS         += ggml-cuda.o
        MK_NVCCFLAGS  = -use_fast_math
 ifndef JETSON_EOL_MODULE_DETECT
        MK_NVCCFLAGS += --forward-unknown-to-host-compiler
 endif # JETSON_EOL_MODULE_DETECT
-
 ifdef LLAMA_DEBUG
        MK_NVCCFLAGS += -lineinfo
-endif
-
+endif # LLAMA_DEBUG
 ifdef LLAMA_CUDA_NVCC
        NVCC = $(LLAMA_CUDA_NVCC)
 else
index 0c8c9ec430475aca51fa668f14a2fbc3ef6b0af9..526ce732be5b5d8168dd9405a250fcb0d26c4816 100644 (file)
@@ -297,7 +297,7 @@ static void ggml_backend_registry_init(void) {
 void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
     GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
 
-    int id = ggml_backend_registry_count;
+    size_t id = ggml_backend_registry_count;
 
     ggml_backend_registry[id] = (struct ggml_backend_reg) {
         /* .name                = */ {0},
@@ -330,6 +330,8 @@ size_t ggml_backend_reg_find_by_name(const char * name) {
             return i;
         }
     }
+
+    // not found
     return SIZE_MAX;
 }
 
@@ -340,15 +342,15 @@ ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str)
     const char * params = strchr(backend_str, ':');
     char backend_name[128];
     if (params == NULL) {
-        strcpy(backend_name, backend_str);
+        snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
         params = "";
     } else {
-        strncpy(backend_name, backend_str, params - backend_str);
-        backend_name[params - backend_str] = '\0';
+        snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
         params++;
     }
 
     size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+
     if (backend_i == SIZE_MAX) {
         fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
         return NULL;
@@ -396,18 +398,12 @@ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 }
 
 static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
     memcpy((char *)tensor->data + offset, data, size);
 
     GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
     memcpy(data, (const char *)tensor->data + offset, size);
 
     GGML_UNUSED(buffer);
index f9830328be51bbf8f1d283e90d9eed0c213b6b50..ac3b3c14d53df663dda113f0b40107f6ddb41f33 100644 (file)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define __trap abort
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 #else
 #include <cuda_runtime.h>
+#include <cuda.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
-// CUDA 10.2 does not have these macro definitions.
-#ifndef CUBLAS_TF32_TENSOR_OP_MATH
+
+#if CUDART_VERSION < 11020
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
 #define CUBLAS_COMPUTE_16F CUDA_R_16F
 #define CUBLAS_COMPUTE_32F CUDA_R_32F
 #define cublasComputeType_t cudaDataType_t
-#endif
+#endif // CUDART_VERSION < 11020
+
 #endif // defined(GGML_USE_HIPBLAS)
 
 #include "ggml-cuda.h"
@@ -200,45 +211,45 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
-#define CUDA_CHECK(err)                                                                 \
-    do {                                                                                \
-        cudaError_t err_ = (err);                                                       \
-        if (err_ != cudaSuccess) {                                                      \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
-            fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
-                cudaGetErrorString(err_));                                              \
-            fprintf(stderr, "current device: %d\n", id);                                \
-            GGML_ASSERT(!"CUDA error");                                                 \
-        }                                                                               \
-    } while (0)
-
 #if CUDART_VERSION >= 12000
-#define CUBLAS_CHECK(err)                                                               \
-    do {                                                                                \
-        cublasStatus_t err_ = (err);                                                    \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
-            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
-                    err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
-            fprintf(stderr, "current device: %d\n", id);                                \
-            GGML_ASSERT(!"cuBLAS error");                                               \
-        }                                                                               \
-    } while (0)
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        return cublasGetStatusString(err);
+    }
 #else
-#define CUBLAS_CHECK(err)                                                               \
-    do {                                                                                \
-        cublasStatus_t err_ = (err);                                                    \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
-            fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
-            fprintf(stderr, "current device: %d\n", id);                                \
-            GGML_ASSERT(!"cuBLAS error");                                               \
-        }                                                                               \
-    } while (0)
-#endif // CUDART_VERSION >= 11
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        switch (err) {
+            case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+            case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+            case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+            case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+            case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+            case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+            case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+            case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+            case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
+            default: return "unknown error";
+        }
+    }
+#endif // CUDART_VERSION >= 12000
+
+[[noreturn]]
+static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
+    fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
+    fprintf(stderr, "  in function %s at %s:%d\n", func, file, line);
+    GGML_ASSERT(!"CUDA error");
+}
+
+#define CUDA_CHECK(err)   do { auto err_ = (err); if (err_ != cudaSuccess)           ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_));   } while (0)
+#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
+
+#if !defined(GGML_USE_HIPBLAS)
+static const char * cu_get_error_str(CUresult err) {
+    const char * err_str;
+    cuGetErrorString(err, &err_str);
+    return err_str;
+}
+#define CU_CHECK(err)     do { auto err_ = (err); if (err_ != CUDA_SUCCESS)          ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_));     } while (0)
+#endif
 
 #if CUDART_VERSION >= 11100
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -516,9 +527,17 @@ inline cudaError_t ggml_cuda_set_device(const int device) {
 
 static int g_device_count = -1;
 static int g_main_device = 0;
-static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
+struct cuda_device_capabilities {
+    int     cc;                 // compute capability
+    bool    vmm;                // virtual memory support
+    size_t  vmm_granularity;    // granularity of virtual memory
+};
+
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+
+
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
 static size_t g_scratch_offset = 0;
@@ -5875,7 +5894,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -5920,7 +5939,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -5965,7 +5984,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6010,7 +6029,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6055,7 +6074,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6100,7 +6119,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6147,7 +6166,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6193,7 +6212,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6238,7 +6257,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6283,7 +6302,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_RDNA2) {
@@ -6543,21 +6562,24 @@ struct scoped_spin_lock {
     scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
 };
 
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+// #define DEBUG_CUDA_MALLOC
 struct cuda_buffer {
     void * ptr = nullptr;
     size_t size = 0;
 };
 
 static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
-static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 #ifdef DEBUG_CUDA_MALLOC
     int nnz = 0;
-    size_t max_size = 0, tot_size = 0;
+    size_t max_size = 0;
 #endif
     size_t best_diff = 1ull << 36;
     int ibest = -1;
@@ -6566,7 +6588,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
         if (b.ptr != nullptr) {
 #ifdef DEBUG_CUDA_MALLOC
             ++nnz;
-            tot_size += b.size;
             if (b.size > max_size) max_size = b.size;
 #endif
             if (b.size >= size) {
@@ -6593,19 +6614,20 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
         b.size = 0;
         return ptr;
     }
-#ifdef DEBUG_CUDA_MALLOC
-    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
-            (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
-#endif
     void * ptr;
     size_t look_ahead_size = (size_t) (1.05 * size);
     look_ahead_size = 256 * ((look_ahead_size + 255)/256);
     CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
     *actual_size = look_ahead_size;
+    g_cuda_pool_size[id] += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+    fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
+            (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
+#endif
     return ptr;
 }
 
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
+static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
@@ -6620,7 +6642,151 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     }
     fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
     CUDA_CHECK(cudaFree(ptr));
+    g_cuda_pool_size[id] -= size;
+}
+
+#if !defined(GGML_USE_HIPBLAS)
+// pool with virtual memory
+static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
+static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
+static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
+static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
+
+static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
+    const size_t alignment = 128;
+    size = alignment * ((size + alignment - 1) / alignment);
+
+    size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
+
+    if (size > avail) {
+        // round up to the next multiple of the granularity
+        size_t reserve_size = size - avail;
+        const size_t granularity = g_device_caps[id].vmm_granularity;
+        reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+        GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+        // allocate more physical memory
+        CUmemAllocationProp prop = {};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = id;
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+        // reserve virtual address space (if not already reserved)
+        if (g_cuda_pool_addr[id] == 0) {
+            CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+        }
+
+        // map at the end of the pool
+        CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
+
+        // set access
+        CUmemAccessDesc access = {};
+        access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        access.location.id = id;
+        access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+        CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
+
+        // add to the pool
+        g_cuda_pool_handles[id].push_back(handle);
+        g_cuda_pool_size[id] += reserve_size;
+
+        //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+        //       id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
+        //       (unsigned long long) (reserve_size/1024/1024));
+    }
+
+    GGML_ASSERT(g_cuda_pool_addr[id] != 0);
+
+    void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
+    *actual_size = size;
+    g_cuda_pool_used[id] += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+    printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
+#endif
+
+    return ptr;
+}
+
+static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+#ifdef DEBUG_CUDA_MALLOC
+    printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
+#endif
+
+    g_cuda_pool_used[id] -= size;
+
+    // all deallocations must be in reverse order of the allocations
+    GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
+}
+
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    if (g_device_caps[id].vmm) {
+        return ggml_cuda_pool_malloc_vmm(size, actual_size);
+    } else {
+        return ggml_cuda_pool_malloc_leg(size, actual_size);
+    }
+}
+
+static void ggml_cuda_pool_free(void * ptr, size_t size) {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    if (g_device_caps[id].vmm) {
+        ggml_cuda_pool_free_vmm(ptr, size);
+    } else {
+        ggml_cuda_pool_free_leg(ptr, size);
+    }
 }
+#else
+#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
+#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
+#endif // !defined(GGML_USE_HIPBLAS)
+
+template<typename T>
+struct cuda_pool_alloc {
+    T * ptr = nullptr;
+    size_t actual_size = 0;
+
+    // size is in number of elements
+    T * alloc(size_t size) {
+        GGML_ASSERT(ptr == nullptr);
+        ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
+    cuda_pool_alloc(size_t size) {
+        alloc(size);
+    }
+
+    ~cuda_pool_alloc() {
+        if (ptr != nullptr) {
+            ggml_cuda_pool_free(ptr, actual_size);
+        }
+    }
+
+    T * get() {
+        return ptr;
+    }
+
+    cuda_pool_alloc() = default;
+    cuda_pool_alloc(const cuda_pool_alloc &) = delete;
+    cuda_pool_alloc(cuda_pool_alloc &&) = delete;
+    cuda_pool_alloc& operator=(const cuda_pool_alloc &) = delete;
+    cuda_pool_alloc& operator=(cuda_pool_alloc &&) = delete;
+};
 
 static bool g_cublas_loaded = false;
 
@@ -6660,16 +6826,33 @@ void ggml_init_cublas() {
 #endif
         fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
         for (int id = 0; id < g_device_count; ++id) {
+            int device_vmm = 0;
+
+#if !defined(GGML_USE_HIPBLAS)
+            CUdevice device;
+            CU_CHECK(cuDeviceGet(&device, id));
+            CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+
+            if (device_vmm) {
+                CUmemAllocationProp alloc_prop = {};
+                alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+                alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+                alloc_prop.location.id = id;
+                CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+            }
+#endif // !defined(GGML_USE_HIPBLAS)
+            g_device_caps[id].vmm = !!device_vmm;
+
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-            fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+            fprintf(stderr, "  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+            g_device_caps[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
 #else
-            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+            g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         }
         for (int id = 0; id < g_device_count; ++id) {
@@ -7178,11 +7361,11 @@ static int64_t get_row_rounding(ggml_type type) {
     int64_t max_compute_capability = INT_MIN;
     for (int64_t id = 0; id < g_device_count; ++id) {
         if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
-            if (min_compute_capability > g_compute_capabilities[id]) {
-                min_compute_capability = g_compute_capabilities[id];
+            if (min_compute_capability > g_device_caps[id].cc) {
+                min_compute_capability = g_device_caps[id].cc;
             }
-            if (max_compute_capability < g_compute_capabilities[id]) {
-                max_compute_capability = g_compute_capabilities[id];
+            if (max_compute_capability < g_device_caps[id].cc) {
+                max_compute_capability = g_device_caps[id].cc;
             }
         }
     }
@@ -7297,8 +7480,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
 
     // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_CUDA_F16
-    size_t ash;
-    dfloat * src1_dfloat = nullptr; // dfloat == half
+    cuda_pool_alloc<half> src1_dfloat_a;
+    half * src1_dfloat = nullptr; // dfloat == half
 
     bool src1_convert_f16 =
         src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
@@ -7306,7 +7489,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
 
     if (src1_convert_f16) {
-        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        src1_dfloat = src1_dfloat_a.alloc(ne00);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                                 ne00, 1, sizeof(float), 0, 0,
                                 ne00, 1, sizeof(half),  0, 0, stream);
@@ -7354,12 +7537,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             break;
     }
 
-#ifdef GGML_CUDA_F16
-    if (src1_convert_f16) {
-        ggml_cuda_pool_free(src1_dfloat, ash);
-    }
-#endif // GGML_CUDA_F16
-
     (void) src1;
     (void) dst;
     (void) src1_ddq_i;
@@ -7390,33 +7567,30 @@ inline void ggml_cuda_op_mul_mat_cublas(
     // ldc == nrows of the matrix that cuBLAS writes into
     int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
 
-    const int compute_capability = g_compute_capabilities[id];
+    const int compute_capability = g_device_caps[id].cc;
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        half * src0_as_f16 = nullptr;
-        size_t src0_as = 0;
+        cuda_pool_alloc<half> src0_as_f16;
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
-            to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+            src0_as_f16.alloc(ne);
+            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
         }
-        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
 
-        half * src1_as_f16 = nullptr;
-        size_t src1_as = 0;
+        cuda_pool_alloc<half> src1_as_f16;
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
-            to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+            src1_as_f16.alloc(ne);
+            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+        cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -7425,36 +7599,25 @@ inline void ggml_cuda_op_mul_mat_cublas(
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                                src1_ptr, CUDA_R_16F, ne10,
-                    &beta_f16,   dst_f16, CUDA_R_16F, ldc,
+                    &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
+                                src1_ptr,       CUDA_R_16F, ne10,
+                    &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
                     CUBLAS_COMPUTE_16F,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
-
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
-        }
-
-        if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
-        }
+        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
     }
     else {
-        float * src0_ddq_as_f32 = nullptr;
-        size_t src0_as = 0;
+        cuda_pool_alloc<float> src0_ddq_as_f32;
 
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
-            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+            src0_ddq_as_f32.alloc(row_diff*ne00);
+            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
         }
-        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
 
         const float alpha = 1.0f;
         const float beta = 0.0f;
@@ -7466,10 +7629,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &alpha, src0_ddf_i, ne00,
                             src1_ddf_i, ne10,
                     &beta,  dst_dd_i,   ldc));
-
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
-        }
     }
 
     (void) dst;
@@ -7761,18 +7920,17 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     float * src1_ddf = nullptr;
     float *  dst_ddf = nullptr;
 
-    // as = actual size
-    size_t src0_asf = 0;
-    size_t src1_asf = 0;
-    size_t  dst_asf = 0;
+    cuda_pool_alloc<float> src0_f;
+    cuda_pool_alloc<float> src1_f;
+    cuda_pool_alloc<float>  dst_f;
 
     ggml_cuda_set_device(g_main_device);
-    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     if (src0_on_device) {
         src0_ddf = (float *) src0_extra->data_device[g_main_device];
     } else {
-        src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+        src0_ddf = src0_f.alloc(ggml_nelements(src0));
         CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
     }
 
@@ -7780,14 +7938,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
         if (src1_on_device) {
             src1_ddf = (float *) src1_extra->data_device[g_main_device];
         } else {
-            src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+            src1_ddf = src1_f.alloc(ggml_nelements(src1));
             CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
         }
     }
     if (dst_on_device) {
         dst_ddf = (float *) dst_extra->data_device[g_main_device];
     } else {
-        dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+        dst_ddf = dst_f.alloc(ggml_nelements(dst));
     }
 
     // do the computation
@@ -7799,16 +7957,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
         CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
     }
 
-    if (src0_asf > 0) {
-        ggml_cuda_pool_free(src0_ddf, src0_asf);
-    }
-    if (src1_asf > 0) {
-        ggml_cuda_pool_free(src1_ddf, src1_asf);
-    }
-    if (dst_asf > 0) {
-        ggml_cuda_pool_free(dst_ddf, dst_asf);
-    }
-
     if (dst->backend == GGML_BACKEND_CPU) {
         CUDA_CHECK(cudaDeviceSynchronize());
     }
@@ -8122,17 +8270,17 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(id));
 
         // free buffers again when done
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
         }
         if (src1_asq[id] > 0) {
             ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
         }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
         }
     }
 
@@ -8385,14 +8533,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+    cuda_pool_alloc<half> src1_as_f16(ne1);
+    to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
 
-    size_t dst_as = 0;
-
-    half * dst_f16 = nullptr;
-    char * dst_t   = nullptr;
+    cuda_pool_alloc<half> dst_f16;
+    char * dst_t;
 
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
@@ -8411,8 +8556,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const void * beta  = &beta_f16;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-        dst_t   = (char *) dst_f16;
+        dst_t = (char *) dst_f16.alloc(ne);
 
         nbd2 /= sizeof(float) / sizeof(half);
         nbd3 /= sizeof(float) / sizeof(half);
@@ -8459,9 +8603,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                       (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                beta,  (      char *)       dst_t, cu_data_type, ne01,                dst->nb[2]/sizeof(float), // strideC
+                alpha, (const char *) src0_as_f16,       CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                       (const char *) src1_as_f16.get(), CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                beta,  (      char *)       dst_t,       cu_data_type, ne01,                dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -8469,19 +8613,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;
 
-        const void ** ptrs_src = nullptr;
-              void ** ptrs_dst = nullptr;
-
-        size_t ptrs_src_s = 0;
-        size_t ptrs_dst_s = 0;
-
-        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+        cuda_pool_alloc<const void *> ptrs_src(2*ne23);
+        cuda_pool_alloc<      void *> ptrs_dst(1*ne23);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_t,
-                ptrs_src, ptrs_dst,
+                src0_as_f16, src1_as_f16.get(), dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -8493,30 +8631,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
-                       (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
-                beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-        }
-        if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-        }
     }
 #endif
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
     }
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8529,8 +8656,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
-        if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
-            min_compute_capability = g_compute_capabilities[id];
+        if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+            min_compute_capability = g_device_caps[id].cc;
         }
     }
 
@@ -8843,12 +8970,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
             ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
         }
     } else {
-        size_t as_src1, as_dst;
-        char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1);
-        char *  dst_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst),  &as_dst);
+        cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
+        cuda_pool_alloc<char>  dst_contiguous(sizeof(float)*ggml_nelements(dst));
 
-        src1_row_extra.data_device[g_main_device] = src1_contiguous;
-        dst_row_extra.data_device[g_main_device]  =  dst_contiguous;
+        src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
+        dst_row_extra.data_device[g_main_device]  =  dst_contiguous.get();
 
         const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
             cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
@@ -8868,7 +8994,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
 
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
+                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
                                         nb11, src1_kind, stream));
                 num_src1_rows++;
             }
@@ -8900,14 +9026,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
 
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
+                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
                                         nb1, dst_kind, stream));
                 num_src1_rows++;
             }
         }
-
-        ggml_cuda_pool_free(src1_contiguous, as_src1);
-        ggml_cuda_pool_free(dst_contiguous,  as_dst);
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
@@ -9678,8 +9801,10 @@ static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buff
 
 static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     void * ptr = ggml_cuda_host_malloc(size);
+
     if (ptr == nullptr) {
-        return nullptr;
+        // fallback to cpu buffer
+        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
     }
 
     // FIXME: this is a hack to avoid having to implement a new buffer type
diff --git a/ggml.c b/ggml.c
index 3656422d73767396c3f29c370464c84d5737f1dd..73600ab050ec89627f0991ab8165497d7c9b2754 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -19351,7 +19351,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
                             data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
                         }
                         gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
-                        free(data);
+                        free((void *)data);
                     } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
                         GGML_ASSERT(false && "nested arrays not supported");
                     } else {
diff --git a/ggml.h b/ggml.h
index 338f355a408b3328dfaa4f150edb92e5476f2b4d..67d6bc4f1ef1b2bbfc2ba6bf5e377f8e2c52cb6b 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
 #elif defined(__GNUC__)
 #define GGML_UNREACHABLE() __builtin_unreachable()
+#elif defined(_MSC_VER)
+#define GGML_UNREACHABLE() __assume(0)
 #else
 #define GGML_UNREACHABLE() ((void) 0)
 #endif
index 5699a0fcf3495e02f975bd7882742fd6be7f5477..a24621539f6bd9d048a7d104afe3ca9054dec6c2 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1281,7 +1281,7 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
-        const float EPSILON = 1e-9;
+        const float EPSILON = 1e-9f;
 
         if (!is_float_close(this->f_norm_eps,            other.f_norm_eps,            EPSILON)) return true;
         if (!is_float_close(this->f_norm_rms_eps,        other.f_norm_rms_eps,        EPSILON)) return true;
@@ -10300,7 +10300,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch
                 std::string result = model->vocab.id_to_token[token].text;
                 llama_unescape_whitespace(result);
                 if (length < (int) result.length()) {
-                    return -result.length();
+                    return -(int) result.length();
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
@@ -10330,7 +10330,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch
                 std::string result = model->vocab.id_to_token[token].text;
                 result = llama_decode_text(result);
                 if (length < (int) result.length()) {
-                    return -result.length();
+                    return -(int) result.length();
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
index 14914def565d910a2b100cdf747072a39d9047d8..8ff76c8910c49917518aa0213882a737bba9a30b 100644 (file)
@@ -883,9 +883,6 @@ int main(int argc, const char ** argv) {
             srand(seed);
             const int nargs = 1;
 
-            int64_t ne2[4];
-            ne2[0] = 1;
-
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);