]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : allocate all temporary ggml_tensor_extra_gpu from a fixed-size buffer (#2220)
authorBach Le <redacted>
Fri, 14 Jul 2023 19:00:58 +0000 (03:00 +0800)
committerGitHub <redacted>
Fri, 14 Jul 2023 19:00:58 +0000 (22:00 +0300)
ggml-cuda.cu

index 73cfc55e54d05da3a0a8ce6a856b4b1fe12266ba..0646fa7b2582dcf06170b5fe2c71a71daed0e538 100644 (file)
@@ -3646,6 +3646,22 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
+static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
+static size_t g_temp_tensor_extra_index = 0;
+
+static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
+    if (g_temp_tensor_extras == nullptr) {
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+    }
+
+    size_t alloc_index = g_temp_tensor_extra_index;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+    struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
+    memset(extra, 0, sizeof(*extra));
+
+    return extra;
+}
+
 void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
@@ -3663,8 +3679,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     }
 
     tensor->backend = GGML_BACKEND_GPU;
-    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
-    memset(extra, 0, sizeof(*extra));
+    struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
         tensor->op == GGML_OP_VIEW ||
@@ -3679,10 +3694,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         if (tensor->op == GGML_OP_VIEW) {
             memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
         }
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src0_ddc + offset;
     } else if (tensor->op == GGML_OP_CPY) {
         struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
         void * src1_ddv = src1_extra->data_device[g_main_device];
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src1_ddv;
     } else if (scratch) {
         GGML_ASSERT(size <= g_scratch_size);
@@ -3695,6 +3712,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
             CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
             g_scratch_buffer = data;
         }
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = data + g_scratch_offset;
 
         g_scratch_offset += size;
@@ -3704,6 +3722,8 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         void * data;
         CUDA_CHECK(cudaMalloc(&data, size));
         CUDA_CHECK(cudaMemset(data, 0, size));
+        extra = new ggml_tensor_extra_gpu;
+        memset(extra, 0, sizeof(*extra));
         extra->data_device[g_main_device] = data;
     }