]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Custom RoPE + bettter memory management for CUDA (#2295)
authorKawrakow <redacted>
Fri, 21 Jul 2023 14:27:51 +0000 (17:27 +0300)
committerGitHub <redacted>
Fri, 21 Jul 2023 14:27:51 +0000 (17:27 +0300)
* Custom RoPE + bettter memory management for CUDA

* Adjusted look ahead in ggml_cuda_pool_malloc to 5%

This is sufficient it seems.
We end up using about 200 MB less VRAM that way when running
the 13B model with context 8192.

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu

index 6537897b990f458b1928e3df3d3425276136051f..c07b5461190f0392d0b6720ea8e10970dbbe01a7 100644 (file)
@@ -2423,20 +2423,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-
+#ifdef DEBUG_CUDA_MALLOC
+    int nnz = 0;
+    size_t max_size = 0, tot_size = 0;
+#endif
+    size_t best_diff = 1ull << 36;
+    int ibest = -1;
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
         cuda_buffer& b = g_cuda_buffer_pool[id][i];
-        if (b.size >= size && b.ptr != nullptr) {
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
+        if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+            ++nnz;
+            tot_size += b.size;
+            if (b.size > max_size) max_size = b.size;
+#endif
+            if (b.size >= size) {
+                size_t diff = b.size - size;
+                if (diff < best_diff) {
+                    best_diff = diff;
+                    ibest = i;
+                    if (!best_diff) {
+                        void * ptr = b.ptr;
+                        *actual_size = b.size;
+                        b.ptr = nullptr;
+                        b.size = 0;
+                        return ptr;
+                    }
+                }
+            }
         }
     }
+    if (ibest >= 0) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
+        void * ptr = b.ptr;
+        *actual_size = b.size;
+        b.ptr = nullptr;
+        b.size = 0;
+        return ptr;
+    }
+#ifdef DEBUG_CUDA_MALLOC
+    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
+            (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
+#endif
     void * ptr;
-    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
-    *actual_size = size;
+    size_t look_ahead_size = (size_t) (1.05 * size);
+    look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+    CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+    *actual_size = look_ahead_size;
     return ptr;
 }
 
@@ -2955,8 +2988,13 @@ inline void ggml_cuda_op_rope(
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
-    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+    // RoPE alteration for extended context
+    float freq_base, freq_scale;
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
 
     bool is_glm = mode & 4;