#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap abort
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else
#include <cuda_runtime.h>
+#include <cuda.h> // CUDA driver API, used by the virtual memory pool
#include <cublas_v2.h>
#include <cuda_fp16.h>
-// CUDA 10.2 does not have these macro definitions.
-#ifndef CUBLAS_TF32_TENSOR_OP_MATH
+
+// CUDA versions older than 11.2 do not provide all of these definitions
+#if CUDART_VERSION < 11020
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
-#endif
+#endif // CUDART_VERSION < 11020
+
#endif // defined(GGML_USE_HIPBLAS)
#include "ggml-cuda.h"
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
-#define CUDA_CHECK(err) \
- do { \
- cudaError_t err_ = (err); \
- if (err_ != cudaSuccess) { \
- int id; \
- cudaGetDevice(&id); \
- fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
- cudaGetErrorString(err_)); \
- fprintf(stderr, "current device: %d\n", id); \
- GGML_ASSERT(!"CUDA error"); \
- } \
- } while (0)
-
#if CUDART_VERSION >= 12000
-#define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- int id; \
- cudaGetDevice(&id); \
- fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
- err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
- fprintf(stderr, "current device: %d\n", id); \
- GGML_ASSERT(!"cuBLAS error"); \
- } \
- } while (0)
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ return cublasGetStatusString(err);
+ }
#else
-#define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- int id; \
- cudaGetDevice(&id); \
- fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
- fprintf(stderr, "current device: %d\n", id); \
- GGML_ASSERT(!"cuBLAS error"); \
- } \
- } while (0)
-#endif // CUDART_VERSION >= 11
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ switch (err) {
+ case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+ case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
+ default: return "unknown error";
+ }
+ }
+#endif // CUDART_VERSION >= 12000
+
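+// common fatal-error reporter used by the CUDA_CHECK/CUBLAS_CHECK/CU_CHECK macros below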
+[[noreturn]]
+static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
+    int id = -1; // stays -1 if cudaGetDevice fails
+    cudaGetDevice(&id);
+
+    fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
+    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    GGML_ASSERT(!"CUDA error");
+}
+
+#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0)
+#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
+
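+// checks for CUDA driver API calls, which are only used in CUDA builds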
+#if !defined(GGML_USE_HIPBLAS)
+static const char * cu_get_error_str(CUresult err) {
+ const char * err_str;
+ cuGetErrorString(err, &err_str);
+ return err_str;
+}
+#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0)
+#endif
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
static int g_device_count = -1;
static int g_main_device = 0;
-static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+struct cuda_device_capabilities {
+ int cc; // compute capability
+ bool vmm; // virtual memory support
+ size_t vmm_granularity; // granularity of virtual memory
+};
+
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+
static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+// #define DEBUG_CUDA_MALLOC
struct cuda_buffer {
void * ptr = nullptr;
size_t size = 0;
};
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
-static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
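+
+// legacy buffer pool: caches up to MAX_CUDA_BUFFERS raw cudaMalloc allocations per device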
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
- size_t max_size = 0, tot_size = 0;
+ size_t max_size = 0;
#endif
size_t best_diff = 1ull << 36;
int ibest = -1;
if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
++nnz;
- tot_size += b.size;
if (b.size > max_size) max_size = b.size;
#endif
if (b.size >= size) {
b.size = 0;
return ptr;
}
-#ifdef DEBUG_CUDA_MALLOC
- fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
- (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
-#endif
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size;
+ g_cuda_pool_size[id] += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+ fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
+ (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
+#endif
return ptr;
}
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
+static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr));
+ g_cuda_pool_size[id] -= size;
+}
+
+#if !defined(GGML_USE_HIPBLAS)
+// pool with virtual memory
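+// physical memory is committed on demand and mapped into one contiguous virtual address range per device;
+// allocations are stack-like and must be freed in reverse order (asserted in ggml_cuda_pool_free_vmm)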
+static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
+static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
+static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
+static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
+
+static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
+ scoped_spin_lock lock(g_cuda_pool_lock);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+
+ // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
+ const size_t alignment = 128;
+ size = alignment * ((size + alignment - 1) / alignment);
+
+ size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
+
+ if (size > avail) {
+ // round up to the next multiple of the granularity
+ size_t reserve_size = size - avail;
+ const size_t granularity = g_device_caps[id].vmm_granularity;
+ reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+ GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+ // allocate more physical memory
+ CUmemAllocationProp prop = {};
+ prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+ prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ prop.location.id = id;
+ CUmemGenericAllocationHandle handle;
+ CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+ // reserve virtual address space (if not already reserved)
+ if (g_cuda_pool_addr[id] == 0) {
+ CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+ }
+
+ // map at the end of the pool
+ CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
+
+ // set access
+ CUmemAccessDesc access = {};
+ access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ access.location.id = id;
+ access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+ CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
+
+ // add to the pool
+ g_cuda_pool_handles[id].push_back(handle);
+ g_cuda_pool_size[id] += reserve_size;
+
+ //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+ // id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
+ // (unsigned long long) (reserve_size/1024/1024));
+ }
+
+ GGML_ASSERT(g_cuda_pool_addr[id] != 0);
+
+ void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
+ *actual_size = size;
+ g_cuda_pool_used[id] += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+ printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
+#endif
+
+ return ptr;
+}
+
+static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
+ scoped_spin_lock lock(g_cuda_pool_lock);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+
+#ifdef DEBUG_CUDA_MALLOC
+ printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
+#endif
+
+ g_cuda_pool_used[id] -= size;
+
+ // all deallocations must be in reverse order of the allocations
+ GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
+}
+
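+// dispatch to the VMM pool when the device supports it, otherwise use the legacy buffer pool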
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ if (g_device_caps[id].vmm) {
+ return ggml_cuda_pool_malloc_vmm(size, actual_size);
+ } else {
+ return ggml_cuda_pool_malloc_leg(size, actual_size);
+ }
+}
+
+static void ggml_cuda_pool_free(void * ptr, size_t size) {
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ if (g_device_caps[id].vmm) {
+ ggml_cuda_pool_free_vmm(ptr, size);
+ } else {
+ ggml_cuda_pool_free_leg(ptr, size);
+ }
}
+#else
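+// hipBLAS builds do not use the CUDA driver VMM API, so the legacy buffer pool is always used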
+#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
+#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
+#endif // !defined(GGML_USE_HIPBLAS)
+
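+// RAII helper that returns its pool allocation when it goes out of scope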
+template<typename T>
+struct cuda_pool_alloc {
+ T * ptr = nullptr;
+ size_t actual_size = 0;
+
+ // size is in number of elements
+ T * alloc(size_t size) {
+ GGML_ASSERT(ptr == nullptr);
+ ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
+ return ptr;
+ }
+
+    explicit cuda_pool_alloc(size_t size) {
+        alloc(size);
+    }
+
+ ~cuda_pool_alloc() {
+ if (ptr != nullptr) {
+ ggml_cuda_pool_free(ptr, actual_size);
+ }
+ }
+
+ T * get() {
+ return ptr;
+ }
+
+ cuda_pool_alloc() = default;
+ cuda_pool_alloc(const cuda_pool_alloc &) = delete;
+ cuda_pool_alloc(cuda_pool_alloc &&) = delete;
+ cuda_pool_alloc& operator=(const cuda_pool_alloc &) = delete;
+ cuda_pool_alloc& operator=(cuda_pool_alloc &&) = delete;
+};
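+// note: the pool is per-device, so a cuda_pool_alloc must be destroyed with the same device active as when it was allocated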
static bool g_cublas_loaded = false;
#endif
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) {
+ int device_vmm = 0;
+
+#if !defined(GGML_USE_HIPBLAS)
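+    // check whether the device supports the CUDA driver virtual memory management API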
+ CUdevice device;
+ CU_CHECK(cuDeviceGet(&device, id));
+ CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+
+ if (device_vmm) {
+ CUmemAllocationProp alloc_prop = {};
+ alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+ alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ alloc_prop.location.id = id;
+ CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+ }
+#endif // !defined(GGML_USE_HIPBLAS)
+ g_device_caps[id].vmm = !!device_vmm;
+
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
- fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+ fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+ g_device_caps[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
#else
- g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+ g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
for (int id = 0; id < g_device_count; ++id) {
int64_t max_compute_capability = INT_MIN;
for (int64_t id = 0; id < g_device_count; ++id) {
if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
- if (min_compute_capability > g_compute_capabilities[id]) {
- min_compute_capability = g_compute_capabilities[id];
+ if (min_compute_capability > g_device_caps[id].cc) {
+ min_compute_capability = g_device_caps[id].cc;
}
- if (max_compute_capability < g_compute_capabilities[id]) {
- max_compute_capability = g_compute_capabilities[id];
+ if (max_compute_capability < g_device_caps[id].cc) {
+ max_compute_capability = g_device_caps[id].cc;
}
}
}
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
- size_t ash;
- dfloat * src1_dfloat = nullptr; // dfloat == half
+ cuda_pool_alloc<half> src1_dfloat_a;
+ half * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) {
- src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+ src1_dfloat = src1_dfloat_a.alloc(ne00);
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
ne00, 1, sizeof(float), 0, 0,
ne00, 1, sizeof(half), 0, 0, stream);
break;
}
-#ifdef GGML_CUDA_F16
- if (src1_convert_f16) {
- ggml_cuda_pool_free(src1_dfloat, ash);
- }
-#endif // GGML_CUDA_F16
-
(void) src1;
(void) dst;
(void) src1_ddq_i;
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
- const int compute_capability = g_compute_capabilities[id];
+ const int compute_capability = g_device_caps[id].cc;
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
- half * src0_as_f16 = nullptr;
- size_t src0_as = 0;
+ cuda_pool_alloc<half> src0_as_f16;
if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = row_diff*ne00;
- src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
- to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+ src0_as_f16.alloc(ne);
+ to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
}
- const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
+ const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
- half * src1_as_f16 = nullptr;
- size_t src1_as = 0;
+ cuda_pool_alloc<half> src1_as_f16;
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = src1_ncols*ne10;
- src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
- to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+ src1_as_f16.alloc(ne);
+ to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
- const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
- size_t dst_as = 0;
- half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+ cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
- &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
- src1_ptr, CUDA_R_16F, ne10,
- &beta_f16, dst_f16, CUDA_R_16F, ldc,
+ &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+ src1_ptr, CUDA_R_16F, ne10,
+ &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
- ggml_cuda_pool_free(dst_f16, dst_as);
-
- if (src0_as != 0) {
- ggml_cuda_pool_free(src0_as_f16, src0_as);
- }
-
- if (src1_as != 0) {
- ggml_cuda_pool_free(src1_as_f16, src1_as);
- }
+ to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
else {
- float * src0_ddq_as_f32 = nullptr;
- size_t src0_as = 0;
+ cuda_pool_alloc<float> src0_ddq_as_f32;
if (src0->type != GGML_TYPE_F32) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
GGML_ASSERT(to_fp32_cuda != nullptr);
- src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
- to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+ src0_ddq_as_f32.alloc(row_diff*ne00);
+ to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
}
- const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+ const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float alpha = 1.0f;
const float beta = 0.0f;
&alpha, src0_ddf_i, ne00,
src1_ddf_i, ne10,
&beta, dst_dd_i, ldc));
-
- if (src0_as != 0) {
- ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
- }
}
(void) dst;
float * src1_ddf = nullptr;
float * dst_ddf = nullptr;
- // as = actual size
- size_t src0_asf = 0;
- size_t src1_asf = 0;
- size_t dst_asf = 0;
+ cuda_pool_alloc<float> src0_f;
+ cuda_pool_alloc<float> src1_f;
+ cuda_pool_alloc<float> dst_f;
ggml_cuda_set_device(g_main_device);
- const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
if (src0_on_device) {
src0_ddf = (float *) src0_extra->data_device[g_main_device];
} else {
- src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+ src0_ddf = src0_f.alloc(ggml_nelements(src0));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {
- src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+ src1_ddf = src1_f.alloc(ggml_nelements(src1));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
}
}
if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
- dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+ dst_ddf = dst_f.alloc(ggml_nelements(dst));
}
// do the computation
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
- if (src0_asf > 0) {
- ggml_cuda_pool_free(src0_ddf, src0_asf);
- }
- if (src1_asf > 0) {
- ggml_cuda_pool_free(src1_ddf, src1_asf);
- }
- if (dst_asf > 0) {
- ggml_cuda_pool_free(dst_ddf, dst_asf);
- }
-
if (dst->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
CUDA_CHECK(ggml_cuda_set_device(id));
// free buffers again when done
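+    // free in reverse order of allocation to satisfy the stack-like VMM pool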
- if (src0_as[id] > 0) {
- ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
- }
- if (src1_asf[id] > 0) {
- ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+ if (dst_as[id] > 0) {
+ ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
if (src1_asq[id] > 0) {
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
}
- if (dst_as[id] > 0) {
- ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+ if (src1_asf[id] > 0) {
+ ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+ }
+ if (src0_as[id] > 0) {
+ ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
}
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
- size_t src1_as = 0;
- half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
- to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+ cuda_pool_alloc<half> src1_as_f16(ne1);
+ to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
- size_t dst_as = 0;
-
- half * dst_f16 = nullptr;
- char * dst_t = nullptr;
+ cuda_pool_alloc<half> dst_f16;
+ char * dst_t;
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;
const void * beta = &beta_f16;
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
- dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
- dst_t = (char *) dst_f16;
+ dst_t = (char *) dst_f16.alloc(ne);
nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
- (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
- beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
+ alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
+ (const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+ beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;
- const void ** ptrs_src = nullptr;
- void ** ptrs_dst = nullptr;
-
- size_t ptrs_src_s = 0;
- size_t ptrs_dst_s = 0;
-
- ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
- ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+ cuda_pool_alloc<const void *> ptrs_src(2*ne23);
+ cuda_pool_alloc< void *> ptrs_dst(1*ne23);
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
- src0_as_f16, src1_as_f16, dst_t,
- ptrs_src, ptrs_dst,
+ src0_as_f16, src1_as_f16.get(), dst_t,
+ ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,
ne23,
nb02, nb03,
CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
- (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
- beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+ (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
ne23,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
- if (ptrs_src_s != 0) {
- ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
- }
- if (ptrs_dst_s != 0) {
- ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
- }
}
#endif
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
- ggml_cuda_pool_free(dst_f16, dst_as);
+ to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
}
-
- ggml_cuda_pool_free(src1_as_f16, src1_as);
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
int64_t min_compute_capability = INT_MAX;
for (int64_t id = 0; id < g_device_count; ++id) {
- if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
- min_compute_capability = g_compute_capabilities[id];
+ if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+ min_compute_capability = g_device_caps[id].cc;
}
}
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
}
} else {
- size_t as_src1, as_dst;
- char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1);
- char * dst_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst), &as_dst);
+ cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
+ cuda_pool_alloc<char> dst_contiguous(sizeof(float)*ggml_nelements(dst));
- src1_row_extra.data_device[g_main_device] = src1_contiguous;
- dst_row_extra.data_device[g_main_device] = dst_contiguous;
+ src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
+ dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
GGML_ASSERT(row_id >= 0 && row_id < n_as);
- CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
+ CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, src1_kind, stream));
num_src1_rows++;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
- CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
+ CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, dst_kind, stream));
num_src1_rows++;
}
}
-
- ggml_cuda_pool_free(src1_contiguous, as_src1);
- ggml_cuda_pool_free(dst_contiguous, as_dst);
}
if (dst->backend == GGML_BACKEND_CPU) {
static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr = ggml_cuda_host_malloc(size);
+
if (ptr == nullptr) {
- return nullptr;
+ // fallback to cpu buffer
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
}
// FIXME: this is a hack to avoid having to implement a new buffer type