]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda : use graph allocator (#2684)
authorslaren <redacted>
Tue, 22 Aug 2023 13:25:19 +0000 (15:25 +0200)
committerGitHub <redacted>
Tue, 22 Aug 2023 13:25:19 +0000 (15:25 +0200)
use a different function for no_alloc to avoid breaking backwards compat, fixes lora

remove 512 n_batch limit

fixed 2048 batch size

cleanup

Co-authored-by: Johannes Gäßler <redacted>
common/common.cpp
ggml-cuda.cu
ggml-cuda.h
llama.cpp

index d7e1a5725b4837c9756228be5ca1ffc8b123a483..1623ba21f461ae4939e64bfb9ac1fe876a33d6ff 100644 (file)
@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;
index c0fb9fb650e0d22b04a444ef9056b8129f0ac392..8ab29bb2080249d5a182bd29db3083bd9b42aeef 100644 (file)
@@ -3887,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
 
     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3965,8 +3965,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (col >= ncols) {
         return;
@@ -3982,9 +3982,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
     float tmp = 0.0;
 
@@ -4776,9 +4776,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
@@ -4800,15 +4800,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 }
 
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
@@ -6313,7 +6313,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6322,14 +6322,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6386,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }
 
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }
 
 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
index cad05f5fa47ab68e7692ddf16b6a8013825a72a3..f66bb16786af914867c5a49cd107f9ff85267e4d 100644 (file)
@@ -16,9 +16,14 @@ GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void   ggml_cuda_set_main_device(int main_device);
 GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);
index c97aaee6967e1dd223d469315000cbbc9a7f077a..8b151dc84c90c5c024e9a1b514910dc3ff34cb5f 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 
 #include "ggml.h"
 
-#if !defined(GGML_USE_CUBLAS)
-#  include "ggml-alloc.h"
-#  define LLAMA_USE_ALLOCATOR
-#else
-#  define LLAMA_USE_SCRATCH
-#  define LLAMA_MAX_SCRATCH_BUFFERS 16
-#endif
+#include "ggml-alloc.h"
 
 #ifdef GGML_USE_CUBLAS
 #  include "ggml-cuda.h"
@@ -588,14 +582,6 @@ struct llama_state {
 
 static llama_state g_state;
 
-//
-// memory sizes (calculated for n_batch == 512)
-//
-
-// computed for n_ctx == 2048
-// TODO: dynamically determine these sizes
-//       needs modifications in ggml
-
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
@@ -610,76 +596,6 @@ enum e_model {
 static const size_t kB = 1024;
 static const size_t MB = 1024*1024;
 
-static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
-{
-    std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   ((size_t) n_ctx / 16ull +  92ull) * MB },
-        { MODEL_7B,   ((size_t) n_ctx / 16ull + 100ull) * MB },
-        { MODEL_13B,  ((size_t) n_ctx / 12ull + 120ull) * MB },
-        { MODEL_30B,  ((size_t) n_ctx /  9ull + 160ull) * MB },
-        { MODEL_65B,  ((size_t) n_ctx /  6ull + 256ull) * MB }, // guess
-        { MODEL_70B,  ((size_t) n_ctx /  7ull + 164ull) * MB },
-    };
-    return k_sizes;
-}
-
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  128ull * MB },
-        { MODEL_7B,  160ull * MB },
-        { MODEL_13B, 192ull * MB },
-        { MODEL_30B, 256ull * MB },
-        { MODEL_65B, 384ull * MB }, // guess
-        { MODEL_70B, 304ull * MB },
-    };
-    return k_sizes;
-}
-
-// used to store the compute graph tensors + non-scratch data
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   8ull * MB },
-        { MODEL_7B,  10ull * MB },
-        { MODEL_13B, 12ull * MB },
-        { MODEL_30B, 16ull * MB },
-        { MODEL_65B, 24ull * MB }, // guess
-        { MODEL_70B, 24ull * MB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   512ull * kB },
-        { MODEL_7B,   512ull * kB },
-        { MODEL_13B,  640ull * kB },
-        { MODEL_30B,  768ull * kB },
-        { MODEL_65B, 1280ull * kB },
-        { MODEL_70B, 1280ull * kB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size and context to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  128ull },
-        { MODEL_7B,  128ull },
-        { MODEL_13B, 160ull },
-        { MODEL_30B, 208ull },
-        { MODEL_65B, 256ull },
-        { MODEL_70B, 256ull },
-    };
-    return k_sizes;
-}
-
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab     = 32000;
@@ -857,11 +773,9 @@ struct llama_context {
             ggml_metal_free(ctx_metal);
         }
 #endif
-#ifdef LLAMA_USE_ALLOCATOR
         if (alloc) {
             ggml_allocr_free(alloc);
         }
-#endif
     }
 
     std::mt19937 rng;
@@ -901,17 +815,8 @@ struct llama_context {
     // memory buffers used to evaluate the model
     llama_buffer buf_compute;
 
-#ifdef LLAMA_USE_ALLOCATOR
     llama_buffer buf_alloc;
     ggml_allocr * alloc = NULL;
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-    llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
-
-    int    buf_last = 0;
-    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
-#endif
 
 #ifdef GGML_USE_METAL
     ggml_metal_context * ctx_metal = NULL;
@@ -920,37 +825,6 @@ struct llama_context {
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
-
-    void use_buf(struct ggml_context * ctx, int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        size_t last_size = 0;
-
-        if (i == -1) {
-            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-        } else {
-            auto & buf = buf_scratch[i];
-            last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
-        }
-
-        if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-        }
-
-        buf_last = i;
-#else
-        (void) i;
-        (void) ctx;
-#endif
-    }
-
-    size_t get_buf_max_mem(int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        return buf_max_size[i];
-#else
-        (void) i;
-        return 0;
-#endif
-    }
 };
 
 //
@@ -1620,7 +1494,6 @@ static void llama_model_load_internal(
 
     // prepare memory for the weights
     size_t vram_weights = 0;
-    size_t vram_scratch = 0;
     {
         const uint32_t n_embd     = hparams.n_embd;
         const uint32_t n_embd_gqa = hparams.n_embd_gqa();
@@ -1701,13 +1574,6 @@ static void llama_model_load_internal(
             ctx_size +
             mmapped_size - vram_weights; // weights in VRAM not in memory
 
-#ifndef LLAMA_USE_ALLOCATOR
-        mem_required +=
-            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-#endif
-
         // this is the memory required by one llama_state
         const size_t mem_required_state =
             scale*hparams.kv_size();
@@ -1715,24 +1581,7 @@ static void llama_model_load_internal(
         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
 
-        (void) vram_scratch;
         (void) n_batch;
-#ifdef GGML_USE_CUBLAS
-        if (low_vram) {
-            LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
-            ggml_cuda_set_scratch_size(0); // disable scratch
-        } else {
-            const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
-            const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
-            vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
-            ggml_cuda_set_scratch_size(vram_scratch);
-            if (n_gpu_layers > 0) {
-                LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
-                        __func__, vram_scratch_base / kB, vram_scratch_per_context,
-                        (vram_scratch + MB - 1) / MB); // round up
-            }
-        }
-#endif // GGML_USE_CUBLAS
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
         const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -1769,8 +1618,8 @@ static void llama_model_load_internal(
 
         LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
                 __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-        LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
-                __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
+        LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
+                __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
 #else
         (void) n_gpu_layers;
 #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -1875,9 +1724,7 @@ static struct ggml_cgraph * llama_build_graph(
         /*.no_alloc   =*/ false,
     };
 
-#ifdef LLAMA_USE_ALLOCATOR
     params.no_alloc = true;
-#endif
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1889,14 +1736,10 @@ static struct ggml_cgraph * llama_build_graph(
     if (tokens) {
         struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 
-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
         }
-#else
-        memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
-#endif
         ggml_set_name(inp_tokens, "inp_tokens");
 
         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@@ -1907,14 +1750,10 @@ static struct ggml_cgraph * llama_build_graph(
 
         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
 
-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
         }
-#else
-        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
-#endif
     }
 
     const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1931,25 +1770,21 @@ static struct ggml_cgraph * llama_build_graph(
 
 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > n_layer) {
-        offload_func_nr = ggml_cuda_assign_buffers;
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 1) {
-        offload_func_v  = ggml_cuda_assign_buffers;
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 2) {
-        offload_func_kq = ggml_cuda_assign_buffers;
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
     }
 #endif // GGML_USE_CUBLAS
 
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
     }
-#else
-    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
-#endif
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     for (int il = 0; il < n_layer; ++il) {
@@ -1959,14 +1794,12 @@ static struct ggml_cgraph * llama_build_graph(
 
 #ifdef GGML_USE_CUBLAS
         if (il >= i_gpu_start) {
-            offload_func = ggml_cuda_assign_buffers;
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
         }
 #endif // GGML_USE_CUBLAS
 
         struct ggml_tensor * inpSA = inpL;
 
-        lctx.use_buf(ctx0, 0);
-
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2104,8 +1937,6 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_name(cur, "result_wo");
         }
 
-        lctx.use_buf(ctx0, 1);
-
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
         offload_func(inpFF);
         ggml_set_name(inpFF, "inpFF");
@@ -2160,8 +1991,6 @@ static struct ggml_cgraph * llama_build_graph(
         inpL = cur;
     }
 
-    lctx.use_buf(ctx0, 0);
-
     // norm
     {
         cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2178,8 +2007,6 @@ static struct ggml_cgraph * llama_build_graph(
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");
 
-    lctx.use_buf(ctx0, -1);
-
     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);
 
@@ -2189,15 +2016,6 @@ static struct ggml_cgraph * llama_build_graph(
         mem_per_token = ggml_used_mem(ctx0)/N;
     }
 
-#if 0
-    LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
-            ggml_used_mem(ctx0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.work_buffer.size()/1024.0/1024.0,
-            n_past, N);
-#endif
-
     ggml_free(ctx0);
 
     return gf;
@@ -2248,14 +2066,26 @@ static bool llama_eval_internal(
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_vocab     = hparams.n_vocab;
 
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_reset(lctx.alloc);
-#endif
 
     ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
 
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc_graph(lctx.alloc, gf);
+
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < gf->n_leafs; i++) {
+        ggml_tensor * node = gf->leafs[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
+
+    for (int i = 0; i < gf->n_nodes; i++) {
+        ggml_tensor * node = gf->nodes[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
 #endif
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -4319,7 +4149,6 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
-#ifdef LLAMA_USE_ALLOCATOR
         {
             static const size_t tensor_alignment = 32;
             // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
@@ -4350,13 +4179,6 @@ struct llama_context * llama_new_context_with_model(
 
             LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
 
-            // debug - for comparison with scratch buffer
-            //size_t prev_req =
-            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
-            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
-            //    MEM_REQ_EVAL().at(ctx->model.type);
-            //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
-
             // recreate allocator with exact memory requirements
             ggml_allocr_free(ctx->alloc);
 
@@ -4367,15 +4189,16 @@ struct llama_context * llama_new_context_with_model(
                 ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
             }
 #endif
-        }
-#else
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
-        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
+#ifdef GGML_USE_CUBLAS
+            if (params.low_vram) {
+                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+                ggml_cuda_set_scratch_size(0); // disable scratch
+            } else {
+                ggml_cuda_set_scratch_size(alloc_size);
+                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+            }
 #endif
+        }
     }
 
 #ifdef GGML_USE_METAL