]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : fix vmm pool with multi GPU (llama/4620)
authorslaren <redacted>
Tue, 26 Dec 2023 20:23:59 +0000 (21:23 +0100)
committerGeorgi Gerganov <redacted>
Wed, 27 Dec 2023 09:40:50 +0000 (11:40 +0200)
* cuda : fix vmm pool with multi GPU

* hip

* use recommended granularity instead of minimum

* better error checking

* fix mixtral

* use cudaMemcpy3DPeerAsync

* use cuda_pool_alloc in ggml_cuda_op_mul_mat

* consolidate error checking in ggml_cuda_set_device

* remove unnecessary inlines

ggml-ci

* style fixes

* only use vmm for the main device

* fix scratch buffer size, re-enable vmm pool for all devices

* remove unnecessary check id != g_main_device

src/ggml-cuda.cu
src/ggml.c

index f32e83ab695b865e1c50802c1802e0807cce8e7c..abad9cc39e2cf051691ed51c0fd49a7434163aa5 100644 (file)
@@ -68,8 +68,9 @@
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #endif
 #define cudaMemcpy hipMemcpy
-#define cudaMemcpy2DAsync hipMemcpy2DAsync
 #define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
 #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
 #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
@@ -163,7 +164,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
     const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
 #if __has_builtin(__builtin_elementwise_sub_sat)
     const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
-    return reinterpret_cast<const int&>(c);
+    return reinterpret_cast<const int &>(c);
 #else
     int8x4_t c;
     int16_t tmp;
@@ -174,7 +175,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
         if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
         c[i] = tmp;
     }
-    return reinterpret_cast<int&>(c);
+    return reinterpret_cast<int &>(c);
 #endif // __has_builtin(__builtin_elementwise_sub_sat)
 }
 
@@ -212,6 +213,28 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
+[[noreturn]]
+static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
+    int id = -1; // in case cudaGetDevice fails
+    cudaGetDevice(&id);
+
+    fprintf(stderr, "CUDA error: %s\n", msg);
+    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    fprintf(stderr, "  %s\n", stmt);
+    // abort with GGML_ASSERT to get a stack trace
+    GGML_ASSERT(!"CUDA error");
+}
+
+#define CUDA_CHECK_GEN(err, success, error_fn)                                      \
+     do {                                                                           \
+        auto err_ = (err);                                                          \
+        if (err_ != (success)) {                                                    \
+            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));    \
+        }                                                                           \
+    } while (0)
+
+#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
+
 #if CUDART_VERSION >= 12000
     static const char * cublas_get_error_str(const cublasStatus_t err) {
         return cublasGetStatusString(err);
@@ -233,15 +256,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     }
 #endif // CUDART_VERSION >= 12000
 
-[[noreturn]]
-static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
-    fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
-    fprintf(stderr, "  in function %s at %s:%d\n", func, file, line);
-    GGML_ASSERT(!"CUDA error");
-}
-
-#define CUDA_CHECK(err)   do { auto err_ = (err); if (err_ != cudaSuccess)           ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_));   } while (0)
-#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
+#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
 
 #if !defined(GGML_USE_HIPBLAS)
 static const char * cu_get_error_str(CUresult err) {
@@ -249,7 +264,7 @@ static const char * cu_get_error_str(CUresult err) {
     cuGetErrorString(err, &err_str);
     return err_str;
 }
-#define CU_CHECK(err)     do { auto err_ = (err); if (err_ != CUDA_SUCCESS)          ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_));     } while (0)
+#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
 #if CUDART_VERSION >= 11100
@@ -306,10 +321,10 @@ typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * s
 typedef void (*ggml_cuda_op_mul_mat_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, const cudaStream_t & stream);
+    const int64_t src1_padded_row_size, cudaStream_t stream);
 typedef void (*ggml_cuda_op_flatten_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream);
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -515,15 +530,15 @@ struct ggml_tensor_extra_gpu {
 
 // this is faster on Windows
 // probably because the Windows CUDA libraries forget to make this check before invoking the drivers
-inline cudaError_t ggml_cuda_set_device(const int device) {
+static void ggml_cuda_set_device(const int device) {
     int current_device;
     CUDA_CHECK(cudaGetDevice(&current_device));
 
     if (device == current_device) {
-        return cudaSuccess;
+        return;
     }
 
-    return cudaSetDevice(device);
+    CUDA_CHECK(cudaSetDevice(device));
 }
 
 static int g_device_count = -1;
@@ -538,7 +553,6 @@ struct cuda_device_capabilities {
 
 static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
 
-
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
 static size_t g_scratch_offset = 0;
@@ -580,6 +594,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
+    GGML_UNUSED(a);
 }
 
 static __device__ __forceinline__ float op_add(const float a, const float b) {
@@ -701,7 +716,7 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
-static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
     const float GELU_QUICK_COEF = -1.702f;
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -710,7 +725,7 @@ static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
     dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
 }
 
-static __global__ void tanh_f32(const float *x, float *dst, int k) {
+static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
@@ -727,7 +742,7 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
-static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
+static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
@@ -780,7 +795,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     }
 }
 
-static __global__ void concat_f32(const float  *x,const float  *y, float *dst, const int ne0, const int ne02) {
+static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -805,7 +820,7 @@ static __global__ void concat_f32(const float  *x,const float  *y, float *dst, c
     }
 }
 
-static __global__ void upscale_f32(const float  *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
+static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
     int ne0 = ne00 * scale_factor;
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
@@ -825,7 +840,7 @@ static __global__ void upscale_f32(const float  *x, float *dst, const int ne00,
     dst[offset_dst] = x[offset_src];
 }
 
-static __global__ void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -4727,7 +4742,6 @@ static __global__ void mul_mat_p021_f16_f32(
 
         const int row_y = col_x;
 
-
         // y is not transposed but permuted
         const int iy = channel*nrows_y + row_y;
 
@@ -5402,7 +5416,7 @@ struct bin_bcast_cuda {
             cne[3] = 1;
         };
 
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+        auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
             cnb[1] *= cne[1];
             cnb[2] *= cne[2];
             cnb[3] *= cne[3];
@@ -6566,18 +6580,16 @@ struct scoped_spin_lock {
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
 // #define DEBUG_CUDA_MALLOC
-struct cuda_buffer {
+struct ggml_cuda_buffer {
     void * ptr = nullptr;
     size_t size = 0;
 };
 
-static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
+static ggml_cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
 static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
 
-static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
 #ifdef DEBUG_CUDA_MALLOC
     int nnz = 0;
     size_t max_size = 0;
@@ -6585,7 +6597,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
     size_t best_diff = 1ull << 36;
     int ibest = -1;
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[id][i];
+        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
         if (b.ptr != nullptr) {
 #ifdef DEBUG_CUDA_MALLOC
             ++nnz;
@@ -6608,7 +6620,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
         }
     }
     if (ibest >= 0) {
-        cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
+        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][ibest];
         void * ptr = b.ptr;
         *actual_size = b.size;
         b.ptr = nullptr;
@@ -6618,9 +6630,10 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
     void * ptr;
     size_t look_ahead_size = (size_t) (1.05 * size);
     look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+    ggml_cuda_set_device(device);
     CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
     *actual_size = look_ahead_size;
-    g_cuda_pool_size[id] += look_ahead_size;
+    g_cuda_pool_size[device] += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
     fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
             (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
@@ -6628,13 +6641,11 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
+static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[id][i];
+        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
         if (b.ptr == nullptr) {
             b.ptr = ptr;
             b.size = size;
@@ -6642,73 +6653,73 @@ static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
         }
     }
     fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    ggml_cuda_set_device(device);
     CUDA_CHECK(cudaFree(ptr));
-    g_cuda_pool_size[id] -= size;
+    g_cuda_pool_size[device] -= size;
 }
 
 #if !defined(GGML_USE_HIPBLAS)
 // pool with virtual memory
-static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
 static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
 static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
 static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
 
-static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
 
     // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
     const size_t alignment = 128;
     size = alignment * ((size + alignment - 1) / alignment);
 
-    size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
+    size_t avail = g_cuda_pool_size[device] - g_cuda_pool_used[device];
 
     if (size > avail) {
         // round up to the next multiple of the granularity
         size_t reserve_size = size - avail;
-        const size_t granularity = g_device_caps[id].vmm_granularity;
+        const size_t granularity = g_device_caps[device].vmm_granularity;
         reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
 
-        GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+        GGML_ASSERT(g_cuda_pool_size[device] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
 
         // allocate more physical memory
         CUmemAllocationProp prop = {};
         prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
         prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
-        prop.location.id = id;
+        prop.location.id = device;
         CUmemGenericAllocationHandle handle;
         CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
 
         // reserve virtual address space (if not already reserved)
-        if (g_cuda_pool_addr[id] == 0) {
-            CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+        if (g_cuda_pool_addr[device] == 0) {
+            CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[device], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
         }
 
         // map at the end of the pool
-        CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
+        CU_CHECK(cuMemMap(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, 0, handle, 0));
+
+        // the memory allocation handle is no longer needed after mapping
+        CU_CHECK(cuMemRelease(handle));
 
         // set access
         CUmemAccessDesc access = {};
         access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
-        access.location.id = id;
+        access.location.id = device;
         access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
-        CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
+        CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, &access, 1));
 
         // add to the pool
-        g_cuda_pool_handles[id].push_back(handle);
-        g_cuda_pool_size[id] += reserve_size;
+        g_cuda_pool_size[device] += reserve_size;
 
         //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
         //       id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
         //       (unsigned long long) (reserve_size/1024/1024));
     }
 
-    GGML_ASSERT(g_cuda_pool_addr[id] != 0);
+    GGML_ASSERT(g_cuda_pool_addr[device] != 0);
 
-    void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
+    void * ptr = (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]);
     *actual_size = size;
-    g_cuda_pool_used[id] += size;
+    g_cuda_pool_used[device] += size;
 
 #ifdef DEBUG_CUDA_MALLOC
     printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
@@ -6717,38 +6728,32 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
+static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
 
 #ifdef DEBUG_CUDA_MALLOC
     printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
 #endif
 
-    g_cuda_pool_used[id] -= size;
+    g_cuda_pool_used[device] -= size;
 
     // all deallocations must be in reverse order of the allocations
-    GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
+    GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]));
 }
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-    if (g_device_caps[id].vmm) {
-        return ggml_cuda_pool_malloc_vmm(size, actual_size);
+static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_size) {
+    if (g_device_caps[device].vmm) {
+        return ggml_cuda_pool_malloc_vmm(device, size, actual_size);
     } else {
-        return ggml_cuda_pool_malloc_leg(size, actual_size);
+        return ggml_cuda_pool_malloc_leg(device, size, actual_size);
     }
 }
 
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-    if (g_device_caps[id].vmm) {
-        ggml_cuda_pool_free_vmm(ptr, size);
+static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
+    if (g_device_caps[device].vmm) {
+        ggml_cuda_pool_free_vmm(device, ptr, size);
     } else {
-        ggml_cuda_pool_free_leg(ptr, size);
+        ggml_cuda_pool_free_leg(device, ptr, size);
     }
 }
 #else
@@ -6758,13 +6763,15 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 
 template<typename T>
 struct cuda_pool_alloc {
+    int device = -1;
     T * ptr = nullptr;
     size_t actual_size = 0;
 
     // size is in number of elements
     T * alloc(size_t size) {
         GGML_ASSERT(ptr == nullptr);
-        ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
+        CUDA_CHECK(cudaGetDevice(&device));
+        ptr = (T *) ggml_cuda_pool_malloc(device, size * sizeof(T), &this->actual_size);
         return ptr;
     }
 
@@ -6774,7 +6781,7 @@ struct cuda_pool_alloc {
 
     ~cuda_pool_alloc() {
         if (ptr != nullptr) {
-            ggml_cuda_pool_free(ptr, actual_size);
+            ggml_cuda_pool_free(device, ptr, actual_size);
         }
     }
 
@@ -6839,7 +6846,7 @@ void ggml_init_cublas() {
                 alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
                 alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
                 alloc_prop.location.id = id;
-                CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+                CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
             }
 #endif // !defined(GGML_USE_HIPBLAS)
             g_device_caps[id].vmm = !!device_vmm;
@@ -6861,7 +6868,7 @@ void ggml_init_cublas() {
         }
 
         for (int id = 0; id < g_device_count; ++id) {
-            CUDA_CHECK(ggml_cuda_set_device(id));
+            ggml_cuda_set_device(id);
 
             // create cuda streams
             for (int is = 0; is < MAX_STREAMS; ++is) {
@@ -6976,7 +6983,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
 
 static void ggml_cuda_op_get_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
+    const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t stream) {
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -7018,9 +7025,9 @@ static void ggml_cuda_op_get_rows(
 }
 
 template<class op>
-inline void ggml_cuda_op_bin_bcast(
+static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
@@ -7039,7 +7046,7 @@ inline void ggml_cuda_op_bin_bcast(
 
 static void ggml_cuda_op_repeat(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
+    const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t main_stream) {
 
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
 
@@ -7047,16 +7054,16 @@ static void ggml_cuda_op_repeat(
     (void) src1_d;
 }
 
-inline void ggml_cuda_op_add(
+static void ggml_cuda_op_add(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
-inline void ggml_cuda_op_acc(
+static void ggml_cuda_op_acc(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -7073,23 +7080,23 @@ inline void ggml_cuda_op_acc(
     (void) dst;
 }
 
-inline void ggml_cuda_op_mul(
+static void ggml_cuda_op_mul(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
-inline void ggml_cuda_op_div(
+static void ggml_cuda_op_div(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
-inline void ggml_cuda_op_gelu(
+static void ggml_cuda_op_gelu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7101,9 +7108,9 @@ inline void ggml_cuda_op_gelu(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_silu(
+static void ggml_cuda_op_silu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7115,9 +7122,9 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_gelu_quick(
+static void ggml_cuda_op_gelu_quick(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7129,9 +7136,9 @@ inline void ggml_cuda_op_gelu_quick(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_tanh(
+static void ggml_cuda_op_tanh(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7143,9 +7150,9 @@ inline void ggml_cuda_op_tanh(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_relu(
+static void ggml_cuda_op_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7157,9 +7164,9 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_leaky_relu(
+static void ggml_cuda_op_leaky_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7174,9 +7181,9 @@ inline void ggml_cuda_op_leaky_relu(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_sqr(
+static void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7188,9 +7195,9 @@ inline void ggml_cuda_op_sqr(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_norm(
+static void ggml_cuda_op_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7208,10 +7215,9 @@ inline void ggml_cuda_op_norm(
     (void) src1_dd;
 }
 
-
-inline void ggml_cuda_op_group_norm(
+static void ggml_cuda_op_group_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7225,9 +7231,9 @@ inline void ggml_cuda_op_group_norm(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_concat(
+static void ggml_cuda_op_concat(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -7241,9 +7247,9 @@ inline void ggml_cuda_op_concat(
     (void) dst;
 }
 
-inline void ggml_cuda_op_upscale(
+static void ggml_cuda_op_upscale(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -7258,9 +7264,9 @@ inline void ggml_cuda_op_upscale(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_pad(
+static void ggml_cuda_op_pad(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -7275,9 +7281,9 @@ inline void ggml_cuda_op_pad(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_rms_norm(
+static void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7295,10 +7301,10 @@ inline void ggml_cuda_op_rms_norm(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_mul_mat_q(
+static void ggml_cuda_op_mul_mat_q(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
 
     const int64_t ne00 = src0->ne[0];
 
@@ -7360,7 +7366,7 @@ inline void ggml_cuda_op_mul_mat_q(
 static int64_t get_row_rounding(ggml_type type) {
     int64_t min_compute_capability = INT_MAX;
     int64_t max_compute_capability = INT_MIN;
-    for (int64_t id = 0; id < g_device_count; ++id) {
+    for (int id = 0; id < g_device_count; ++id) {
         if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
             if (min_compute_capability > g_device_caps[id].cc) {
                 min_compute_capability = g_device_caps[id].cc;
@@ -7418,10 +7424,10 @@ static int64_t get_row_rounding(ggml_type type) {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
-inline void ggml_cuda_op_mul_mat_vec_q(
+static void ggml_cuda_op_mul_mat_vec_q(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
 
     GGML_ASSERT(ggml_nrows(src1) == 1);
 
@@ -7471,10 +7477,10 @@ inline void ggml_cuda_op_mul_mat_vec_q(
     (void) src1_padded_row_size;
 }
 
-inline void ggml_cuda_op_dequantize_mul_mat_vec(
+static void ggml_cuda_op_dequantize_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
@@ -7545,10 +7551,10 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
     (void) src1_padded_row_size;
 }
 
-inline void ggml_cuda_op_mul_mat_cublas(
+static void ggml_cuda_op_mul_mat_cublas(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
 
     GGML_ASSERT(src0_dd_i  != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
@@ -7637,9 +7643,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
     (void) src1_padded_row_size;
 }
 
-inline void ggml_cuda_op_rope(
+static void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
@@ -7717,9 +7723,9 @@ inline void ggml_cuda_op_rope(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_alibi(
+static void ggml_cuda_op_alibi(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7748,9 +7754,9 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_im2col(
+static void ggml_cuda_op_im2col(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -7783,10 +7789,9 @@ inline void ggml_cuda_op_im2col(
     (void) src0_dd;
 }
 
-
-inline void ggml_cuda_op_sum_rows(
+static void ggml_cuda_op_sum_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7801,9 +7806,9 @@ inline void ggml_cuda_op_sum_rows(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_argsort(
+static void ggml_cuda_op_argsort(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_I32);
@@ -7820,9 +7825,9 @@ inline void ggml_cuda_op_argsort(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_diag_mask_inf(
+static void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7840,9 +7845,9 @@ inline void ggml_cuda_op_diag_mask_inf(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_soft_max(
+static void ggml_cuda_op_soft_max(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7861,9 +7866,9 @@ inline void ggml_cuda_op_soft_max(
     (void) dst;
 }
 
-inline void ggml_cuda_op_scale(
+static void ggml_cuda_op_scale(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7879,9 +7884,9 @@ inline void ggml_cuda_op_scale(
     (void) src1_dd;
 }
 
-inline void ggml_cuda_op_clamp(
+static void ggml_cuda_op_clamp(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -7974,12 +7979,12 @@ static void ggml_cuda_set_peer_access(const int n_tokens) {
 
 #ifdef NDEBUG
     for (int id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(ggml_cuda_set_device(id));
+        ggml_cuda_set_device(id);
         CUDA_CHECK(cudaDeviceSynchronize());
     }
 
     for (int id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(ggml_cuda_set_device(id));
+        ggml_cuda_set_device(id);
 
         for (int id_other = 0; id_other < g_device_count; ++id_other) {
             if (id == id_other) {
@@ -8013,7 +8018,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    const int64_t nrows0 = ggml_nrows(src0);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -8056,27 +8060,29 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
 
-    // dd = data device
-    char  *  src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
-    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
-    char  * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1
-    float *   dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    struct dev_data {
+        cuda_pool_alloc<char>  src0_dd_alloc;
+        cuda_pool_alloc<float> src1_ddf_alloc;
+        cuda_pool_alloc<char>  src1_ddq_alloc;
+        cuda_pool_alloc<float>   dst_dd_alloc;
 
-    // as = actual size
-    size_t  src0_as[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t   dst_as[GGML_CUDA_MAX_DEVICES] = {0};
+        char  *  src0_dd = nullptr;
+        float * src1_ddf = nullptr; // float
+        char  * src1_ddq = nullptr; // q8_1
+        float *   dst_dd = nullptr;
 
-    int64_t  row_low[GGML_CUDA_MAX_DEVICES];
-    int64_t row_high[GGML_CUDA_MAX_DEVICES];
+        int64_t  row_low;
+        int64_t row_high;
+    };
+
+    dev_data dev[GGML_CUDA_MAX_DEVICES];
 
     int used_devices = 0;
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
+    for (int id = 0; id < g_device_count; ++id) {
         // by default, use all rows
-        row_low[id]  = 0;
-        row_high[id] = ne01;
+        dev[id].row_low  = 0;
+        dev[id].row_high = ne01;
 
         // for multi GPU, get the row boundaries from tensor split
         // and round to mul_mat_q tile sizes
@@ -8084,23 +8090,23 @@ static void ggml_cuda_op_mul_mat(
             const int64_t rounding = get_row_rounding(src0->type);
 
             if (id != 0) {
-                row_low[id]  = ne01*g_tensor_split[id];
-                if (row_low[id] < ne01) {
-                    row_low[id] -= row_low[id] % rounding;
+                dev[id].row_low  = ne01*g_tensor_split[id];
+                if (dev[id].row_low < ne01) {
+                    dev[id].row_low -= dev[id].row_low % rounding;
                 }
             }
 
             if (id != g_device_count - 1) {
-                row_high[id]  = ne01*g_tensor_split[id + 1];
-                if (row_high[id] < ne01) {
-                    row_high[id] -= row_high[id] % rounding;
+                dev[id].row_high  = ne01*g_tensor_split[id + 1];
+                if (dev[id].row_high < ne01) {
+                    dev[id].row_high -= dev[id].row_high % rounding;
                 }
             }
         }
     }
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+    for (int id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || dev[id].row_low == dev[id].row_high) {
             continue;
         }
 
@@ -8110,42 +8116,41 @@ static void ggml_cuda_op_mul_mat(
         const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
         ggml_cuda_set_device(id);
-        const cudaStream_t stream = g_cudaStreams[id][0];
+        cudaStream_t stream = g_cudaStreams[id][0];
 
         if (src0_on_device && src0_is_contiguous) {
-            src0_dd[id] = (char *) src0_extra->data_device[id];
+            dev[id].src0_dd = (char *) src0_extra->data_device[id];
         } else {
-            // const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ggml_nbytes(src0));
         }
 
         if (src1_on_device && src1_is_contiguous) {
-            src1_ddf[id] = (float *) src1_extra->data_device[id];
+            dev[id].src1_ddf = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
+            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ggml_nelements(src1));
         }
 
         if (convert_src1_to_q8_1) {
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
+                quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
 
         if (dst_on_device) {
-            dst_dd[id] = (float *) dst_extra->data_device[id];
+            dev[id].dst_dd = (float *) dst_extra->data_device[id];
         } else {
-            const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
+            const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
+            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(size_dst_ddf);
         }
     }
 
     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && used_devices > 1) {
-        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+        ggml_cuda_set_device(g_main_device);
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
 
@@ -8154,17 +8159,17 @@ static void ggml_cuda_op_mul_mat(
         const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
 
-        for (int64_t id = 0; id < g_device_count; ++id) {
-            if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+        for (int id = 0; id < g_device_count; ++id) {
+            if ((!split && id != g_main_device) || dev[id].row_low == dev[id].row_high) {
                 continue;
             }
 
             const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
             const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
-            const int64_t row_diff = row_high[id] - row_low[id];
+            const int64_t row_diff = dev[id].row_high - dev[id].row_low;
 
             ggml_cuda_set_device(id);
-            const cudaStream_t stream = g_cudaStreams[id][is];
+            cudaStream_t stream = g_cudaStreams[id][is];
 
             // wait for main GPU data if necessary
             if (split && (id != g_main_device || is != 0)) {
@@ -8178,34 +8183,34 @@ static void ggml_cuda_op_mul_mat(
                 const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
-                float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
-                char  * src1_ddq_i = src1_ddq[id] +  src1_ddq_i_offset;
-                float *   dst_dd_i =   dst_dd[id] + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+                char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
+                float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
 
                 // the main device memory buffer can be on VRAM scratch, with space for all partial results
                 // in that case an offset on dst_ddf_i is needed
                 if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
-                    dst_dd_i += row_low[id]; // offset is 0 if no tensor split
+                    dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
                 }
 
                 // copy src0, src1 to device if necessary
                 if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                     if (id != g_main_device) {
                         if (convert_src1_to_q8_1) {
-                            char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset;
-                            CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs,
-                                                    cudaMemcpyDeviceToDevice, stream));
+                            char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset;
+                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, g_main_device,
+                                                            src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
                         } else {
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                             src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
-                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float),
-                                                    cudaMemcpyDeviceToDevice, stream));
+                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, g_main_device,
+                                                            src1_ncols*ne10*sizeof(float), stream));
                         }
                     }
                 } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
                     CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
-                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                                src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                 } else {
                     GGML_ASSERT(false);
                 }
@@ -8216,12 +8221,12 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream));
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
                 // do the computation
                 op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-                   row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream);
+                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
                 CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
@@ -8245,9 +8250,25 @@ static void ggml_cuda_op_mul_mat(
                         // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0 + row_low[id];
-                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
-                                                    row_diff*sizeof(float), src1_ncols, kind, stream));
+                        dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
+#if !defined(GGML_USE_HIPBLAS)
+                        if (kind == cudaMemcpyDeviceToDevice) {
+                            // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+                            cudaMemcpy3DPeerParms p = {};
+                            p.dstDevice = g_main_device;
+                            p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
+                            p.srcDevice = id;
+                            p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
+                            p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
+                            CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
+                        } else
+#endif
+                        {
+                            CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
+                                                            dst_dd_i, row_diff*sizeof(float),
+                                                            row_diff*sizeof(float), src1_ncols,
+                                                            kind, stream));
+                        }
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
@@ -8264,35 +8285,14 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
-            continue;
-        }
-        CUDA_CHECK(ggml_cuda_set_device(id));
-
-        // free buffers again when done
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
-        }
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-        }
-    }
-
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
         is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
 
-        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
-        for (int64_t id = 0; id < g_device_count; ++id) {
-            if (row_low[id] == row_high[id]) {
+        ggml_cuda_set_device(g_main_device);
+        for (int id = 0; id < g_device_count; ++id) {
+            if (dev[id].row_low == dev[id].row_high) {
                 continue;
             }
             for (int64_t is = 0; is < is_max; ++is) {
@@ -8302,7 +8302,7 @@ static void ggml_cuda_op_mul_mat(
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
-        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+        ggml_cuda_set_device(g_main_device);
         CUDA_CHECK(cudaDeviceSynchronize());
     }
 }
@@ -8412,7 +8412,7 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
 
     const int64_t ne12 = src1->ne[2];
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -8444,7 +8444,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 
     const int64_t ne12 = src1->ne[2];
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -8515,7 +8515,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const int64_t ne1 = ggml_nelements(src1);
     const int64_t ne  = ggml_nelements(dst);
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
@@ -8656,7 +8656,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
 
     int64_t min_compute_capability = INT_MAX;
-    for (int64_t id = 0; id < g_device_count; ++id) {
+    for (int id = 0; id < g_device_count; ++id) {
         if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
             min_compute_capability = g_device_caps[id].cc;
         }
@@ -8799,7 +8799,7 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
     const int64_t ne1 = ggml_nelements(src1);
     const int64_t ne  = ggml_nelements(dst);
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
@@ -8917,7 +8917,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
 
     std::vector<char> ids_host(ggml_nbytes(ids));
 
-    const cudaStream_t stream = g_cudaStreams[g_main_device][0];
+    cudaStream_t stream = g_cudaStreams[g_main_device][0];
 
     if (ids->backend == GGML_BACKEND_GPU) {
         const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
@@ -9073,7 +9073,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -9163,7 +9163,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
+    for (int id = 0; id < g_device_count; ++id) {
         if (backend == GGML_BACKEND_GPU && id != g_main_device) {
             continue;
         }
@@ -9234,15 +9234,14 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
+    for (int id = 0; id < g_device_count; ++id) {
+        ggml_cuda_set_device(id);
         if (extra->data_device[id] != nullptr) {
-            CUDA_CHECK(ggml_cuda_set_device(id));
             CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
         for (int64_t is = 0; is < MAX_STREAMS; ++is) {
             if (extra->events[id][is] != nullptr) {
-                CUDA_CHECK(ggml_cuda_set_device(id));
                 CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
             }
         }
@@ -9296,7 +9295,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
         force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
         ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
@@ -9373,7 +9372,7 @@ void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
     GGML_ASSERT(ggml_is_contiguous(tensor));
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    ggml_cuda_set_device(g_main_device);
     CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
 }
 
index d2456048031e202cc17a6b63d4f4b7cdf066663d..ed56e60a8893efa2c9e88b627f87728f65dac89c 100644 (file)
@@ -4041,7 +4041,6 @@ static struct ggml_tensor * ggml_group_norm_impl(
     result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
@@ -5541,7 +5540,6 @@ static struct ggml_tensor * ggml_upscale_impl(
     result->op_params[0] = scale_factor;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5846,7 +5844,6 @@ struct ggml_tensor * ggml_get_rel_pos(
     result->op   = GGML_OP_GET_REL_POS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }