// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {
- const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+ const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
return;
}
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col;
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
}
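For context, the part of rope_f32 elided here applies the standard RoPE rotation to each (col, col + 1) pair using the theta computed above; a minimal sketch of that step (standard formulation, shown only for reference):

    // rotate the pair (x[i], x[i+1]) by the angle theta computed above
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);
    const float x0 = x[i + 0];
    const float x1 = x[i + 1];
    dst[i + 0] = x0*cos_theta - x1*sin_theta;
    dst[i + 1] = x0*sin_theta + x1*cos_theta;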
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
- const int col = blockDim.x*blockIdx.x + threadIdx.x;
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
+ const int col = blockDim.y*blockIdx.y + threadIdx.y;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
if (col >= ncols) {
return;
// values are also not normalized to the maximum value by subtracting it in the exponential function
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
- const int block_size = blockDim.x;
- const int tid = threadIdx.x;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+ const int block_size = blockDim.y;
+ const int tid = threadIdx.y;
float tmp = 0.0;
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 2 == 0);
- const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+ const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
- const dim3 block_nums(num_blocks_x, nrows, 1);
+ const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
- const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+ const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
- const dim3 block_nums(block_num_x, nrows_x, 1);
+ const dim3 block_nums(nrows_x, block_num_x, 1);
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- const dim3 block_nums(1, nrows_x, 1);
+ const dim3 block_dims(1, WARP_SIZE, 1);
+ const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
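The common thread in the kernel and launcher changes above is that rows move from the grid's y dimension to its x dimension, and the per-row threads from block x to block y. CUDA caps gridDim.y at 65535 blocks while gridDim.x allows up to 2^31 - 1, so the swap most likely exists to lift the 65535-row ceiling on these kernels, which matters once the scratch sizes are no longer bounded by the fixed tables removed further down. A minimal sketch of the resulting convention (illustrative names):

    // one block per row along grid.x (limit 2^31 - 1, vs 65535 on grid.y);
    // threads along block.y cover the columns in chunks of block_size
    const dim3 block_dims(1, block_size, 1);
    const dim3 block_nums(nrows, (ncols + block_size - 1) / block_size, 1);
    // inside the kernel, blockDim.x == 1, so:
    //   row = blockIdx.x;
    //   col = blockDim.y*blockIdx.y + threadIdx.y;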
return extra;
}
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
if (scratch && g_scratch_size == 0) {
return;
}
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src[0]->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
- ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+ ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
}
}
if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
- ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+ ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
}
tensor->backend = GGML_BACKEND_GPU;
+
+ if (scratch && no_alloc) {
+ return;
+ }
+
struct ggml_tensor_extra_gpu * extra;
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->extra = extra;
}
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+ if (g_scratch_size == 0) {
+ return;
+ }
+ if (g_scratch_buffer == nullptr) {
+ CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+ }
+
+ struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+ const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+ tensor->op == GGML_OP_VIEW;
+
+ if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+ char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+ size_t view_offset = 0;
+ if (tensor->op == GGML_OP_VIEW) {
+ memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+ }
+ extra->data_device[g_main_device] = src0_ddc + view_offset;
+ } else {
+ extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+ }
+
+ tensor->extra = extra;
+}
+
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
- ggml_cuda_assign_buffers_impl(tensor, true, false);
+ ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+ ggml_cuda_assign_buffers_impl(tensor, true, false, true);
}
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
- ggml_cuda_assign_buffers_impl(tensor, false, false);
+ ggml_cuda_assign_buffers_impl(tensor, false, false, false);
}
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
- ggml_cuda_assign_buffers_impl(tensor, false, true);
+ ggml_cuda_assign_buffers_impl(tensor, false, true, false);
}
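Together these wrappers split GPU offloading into two phases when ggml-alloc is in use: ggml_cuda_assign_buffers_no_alloc only marks a tensor as GPU-backed while the graph is being built, and ggml_cuda_assign_scratch_offset later binds it to a concrete position in the VRAM scratch buffer once the allocator has decided where its data lives. A rough sketch of the intended call order (variable names illustrative; the real call sites are in the llama.cpp changes below):

    // phase 1, during graph construction: mark for offload, defer VRAM placement
    ggml_cuda_assign_buffers_no_alloc(cur);            // scratch = true, no_alloc = true

    // phase 2, after ggml_allocr_alloc_graph() has placed the tensor in buf_alloc:
    // reuse the same byte offset inside the CUDA scratch buffer
    size_t offset = (char *) cur->data - (char *) buf_alloc.data;
    ggml_cuda_assign_scratch_offset(cur, offset);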
void ggml_cuda_set_main_device(int main_device) {
#include "ggml.h"
-#if !defined(GGML_USE_CUBLAS)
-# include "ggml-alloc.h"
-# define LLAMA_USE_ALLOCATOR
-#else
-# define LLAMA_USE_SCRATCH
-# define LLAMA_MAX_SCRATCH_BUFFERS 16
-#endif
+#include "ggml-alloc.h"
#ifdef GGML_USE_CUBLAS
# include "ggml-cuda.h"
static llama_state g_state;
-//
-// memory sizes (calculated for n_batch == 512)
-//
-
-// computed for n_ctx == 2048
-// TODO: dynamically determine these sizes
-// needs modifications in ggml
-
// available llama models
enum e_model {
MODEL_UNKNOWN,
static const size_t kB = 1024;
static const size_t MB = 1024*1024;
-static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
-{
- std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB },
- { MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
- { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
- { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
- { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
- { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
- };
- return k_sizes;
-}
-
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
-{
- static std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, 128ull * MB },
- { MODEL_7B, 160ull * MB },
- { MODEL_13B, 192ull * MB },
- { MODEL_30B, 256ull * MB },
- { MODEL_65B, 384ull * MB }, // guess
- { MODEL_70B, 304ull * MB },
- };
- return k_sizes;
-}
-
-// used to store the compute graph tensors + non-scratch data
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
-{
- static std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, 8ull * MB },
- { MODEL_7B, 10ull * MB },
- { MODEL_13B, 12ull * MB },
- { MODEL_30B, 16ull * MB },
- { MODEL_65B, 24ull * MB }, // guess
- { MODEL_70B, 24ull * MB },
- };
- return k_sizes;
-}
-
-// amount of VRAM needed per batch size to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
-{
- static std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, 512ull * kB },
- { MODEL_7B, 512ull * kB },
- { MODEL_13B, 640ull * kB },
- { MODEL_30B, 768ull * kB },
- { MODEL_65B, 1280ull * kB },
- { MODEL_70B, 1280ull * kB },
- };
- return k_sizes;
-}
-
-// amount of VRAM needed per batch size and context to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
-{
- static std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, 128ull },
- { MODEL_7B, 128ull },
- { MODEL_13B, 160ull },
- { MODEL_30B, 208ull },
- { MODEL_65B, 256ull },
- { MODEL_70B, 256ull },
- };
- return k_sizes;
-}
-
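The hand-tuned per-model scratch and eval tables above can go because ggml-alloc now measures the exact requirement from a worst-case graph when the context is created. A rough sketch of that measure-then-allocate pattern (simplified from the context-creation code this diff touches; exact details differ):

    // measurement pass: a dummy allocator records the peak memory the graph needs
    ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
    size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
    ggml_allocr_free(alloc);

    // real allocator backed by a host buffer of exactly that size;
    // with CUBLAS the same size is also used for the VRAM scratch buffer (see below)
    buf_alloc.resize(alloc_size);
    alloc = ggml_allocr_new(buf_alloc.data, buf_alloc.size, tensor_alignment);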
// default hparams (LLaMA 7B)
struct llama_hparams {
uint32_t n_vocab = 32000;
ggml_metal_free(ctx_metal);
}
#endif
-#ifdef LLAMA_USE_ALLOCATOR
if (alloc) {
ggml_allocr_free(alloc);
}
-#endif
}
std::mt19937 rng;
// memory buffers used to evaluate the model
llama_buffer buf_compute;
-#ifdef LLAMA_USE_ALLOCATOR
llama_buffer buf_alloc;
ggml_allocr * alloc = NULL;
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
- llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
-
- int buf_last = 0;
- size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
-#endif
#ifdef GGML_USE_METAL
ggml_metal_context * ctx_metal = NULL;
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
#endif
-
- void use_buf(struct ggml_context * ctx, int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
- size_t last_size = 0;
-
- if (i == -1) {
- last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
- } else {
- auto & buf = buf_scratch[i];
- last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
- }
-
- if (buf_last >= 0) {
- buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
- }
-
- buf_last = i;
-#else
- (void) i;
- (void) ctx;
-#endif
- }
-
- size_t get_buf_max_mem(int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
- return buf_max_size[i];
-#else
- (void) i;
- return 0;
-#endif
- }
};
//
// prepare memory for the weights
size_t vram_weights = 0;
- size_t vram_scratch = 0;
{
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd_gqa = hparams.n_embd_gqa();
ctx_size +
mmapped_size - vram_weights; // weights in VRAM not in memory
-#ifndef LLAMA_USE_ALLOCATOR
- mem_required +=
- MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
- MEM_REQ_SCRATCH1().at(model.type) +
- MEM_REQ_EVAL().at(model.type);
-#endif
-
// this is the memory required by one llama_state
const size_t mem_required_state =
scale*hparams.kv_size();
LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
- (void) vram_scratch;
(void) n_batch;
-#ifdef GGML_USE_CUBLAS
- if (low_vram) {
- LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
- ggml_cuda_set_scratch_size(0); // disable scratch
- } else {
- const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
- const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
- vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
- ggml_cuda_set_scratch_size(vram_scratch);
- if (n_gpu_layers > 0) {
- LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
- __func__, vram_scratch_base / kB, vram_scratch_per_context,
- (vram_scratch + MB - 1) / MB); // round up
- }
- }
-#endif // GGML_USE_CUBLAS
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
__func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
- LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
- __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
+ LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
+ __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
#else
(void) n_gpu_layers;
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
/*.no_alloc =*/ false,
};
-#ifdef LLAMA_USE_ALLOCATOR
params.no_alloc = true;
-#endif
struct ggml_context * ctx0 = ggml_init(params);
if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
}
-#else
- memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
-#endif
ggml_set_name(inp_tokens, "inp_tokens");
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
-#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_alloc(lctx.alloc, inpL);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
}
-#else
- memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
-#endif
}
const int i_gpu_start = n_layer - n_gpu_layers;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
- offload_func_nr = ggml_cuda_assign_buffers;
+ offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
- offload_func_v = ggml_cuda_assign_buffers;
+ offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
- offload_func_kq = ggml_cuda_assign_buffers;
+ offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
-#else
- ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
-#endif
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
for (int il = 0; il < n_layer; ++il) {
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
- offload_func = ggml_cuda_assign_buffers;
+ offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
struct ggml_tensor * inpSA = inpL;
- lctx.use_buf(ctx0, 0);
-
// norm
{
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
ggml_set_name(cur, "result_wo");
}
- lctx.use_buf(ctx0, 1);
-
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
offload_func(inpFF);
ggml_set_name(inpFF, "inpFF");
inpL = cur;
}
- lctx.use_buf(ctx0, 0);
-
// norm
{
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
- lctx.use_buf(ctx0, -1);
-
// logits -> probs
//cur = ggml_soft_max_inplace(ctx0, cur);
mem_per_token = ggml_used_mem(ctx0)/N;
}
-#if 0
- LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
- ggml_used_mem(ctx0)/1024.0/1024.0,
- lctx.get_buf_max_mem(0)/1024.0/1024.0,
- lctx.get_buf_max_mem(1)/1024.0/1024.0,
- lctx.work_buffer.size()/1024.0/1024.0,
- n_past, N);
-#endif
-
ggml_free(ctx0);
return gf;
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
-#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_reset(lctx.alloc);
-#endif
ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
-#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_alloc_graph(lctx.alloc, gf);
+
+#ifdef GGML_USE_CUBLAS
+ for (int i = 0; i < gf->n_leafs; i++) {
+ ggml_tensor * node = gf->leafs[i];
+ if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+ ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+ }
+ }
+
+ for (int i = 0; i < gf->n_nodes; i++) {
+ ggml_tensor * node = gf->nodes[i];
+ if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+ ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+ }
+ }
#endif
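The effect: whatever offset ggml-alloc gave a GPU-backed node inside lctx.buf_alloc is reused as its offset inside the CUDA scratch buffer, so host-side and VRAM placement stay in one-to-one correspondence. A worked example with an illustrative offset:

    // if the allocator placed a node 0x4000 bytes into buf_alloc:
    //   node->data == (char *) lctx.buf_alloc.data + 0x4000
    // then ggml_cuda_assign_scratch_offset(node, 0x4000) sets its device pointer to
    //   extra->data_device[g_main_device] == (char *) g_scratch_buffer + 0x4000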
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
ctx->embedding.resize(hparams.n_embd);
}
-#ifdef LLAMA_USE_ALLOCATOR
{
static const size_t tensor_alignment = 32;
// the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
- // debug - for comparison with scratch buffer
- //size_t prev_req =
- // MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
- // MEM_REQ_SCRATCH1().at(ctx->model.type) +
- // MEM_REQ_EVAL().at(ctx->model.type);
- //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
-
// recreate allocator with exact memory requirements
ggml_allocr_free(ctx->alloc);
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
- }
-#else
- ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
- ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
- ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
+#ifdef GGML_USE_CUBLAS
+ if (params.low_vram) {
+ LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+ ggml_cuda_set_scratch_size(0); // disable scratch
+ } else {
+ ggml_cuda_set_scratch_size(alloc_size);
+ LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+ }
#endif
+ }
}
#ifdef GGML_USE_METAL