]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : grouped-query attention + LLaMAv2 70B support (#2276)
authorGeorgi Gerganov <redacted>
Sun, 23 Jul 2023 12:09:47 +0000 (15:09 +0300)
committerGitHub <redacted>
Sun, 23 Jul 2023 12:09:47 +0000 (15:09 +0300)
* CUDA: GQA implementation

* llama : support for GQA and LLaMAv2 70B

ggml-ci

* py : fix hparams parsing (if-else blocks)

ggml-ci

* py : oh boy ..

ggml-ci

* help : fix gqa value for 70B

ggml-ci

---------

Co-authored-by: JohannesGaessler <redacted>
convert.py
examples/common.cpp
examples/common.h
examples/main/main.cpp
ggml-cuda.cu
llama.cpp
llama.h

index e3f1096e149c4915584d3068c86bf48242b7a492..8d7af06d13c6f5a8dabc20bb6cee1ae3eadb704f 100755 (executable)
@@ -142,9 +142,9 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab: int
-    n_embd: int
-    n_mult: int
-    n_head: int
+    n_embd:  int
+    n_mult:  int
+    n_head:  int
     n_layer: int
 
     @staticmethod
@@ -167,11 +167,11 @@ class Params:
         n_head=n_embd // 128 # guessed
 
         return Params(
-            n_vocab=n_vocab,
-            n_embd=n_embd,
-            n_mult=256,
-            n_head=n_head,
-            n_layer=n_layer,
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = 256,
+            n_head  = n_head,
+            n_layer = n_layer,
         )
 
     @staticmethod
@@ -179,28 +179,53 @@ class Params:
         config = json.load(open(config_path))
 
         n_vocab = config["vocab_size"];
-        n_embd = config["hidden_size"];
-        n_head = config["num_attention_heads"];
+        n_embd  = config["hidden_size"];
+        n_head  = config["num_attention_heads"];
         n_layer = config["num_hidden_layers"];
-        n_ff = config["intermediate_size"];
+        n_ff    = config["intermediate_size"];
 
         n_mult = find_n_mult(n_ff, n_embd);
 
         return Params(
-            n_vocab=n_vocab,
-            n_embd=n_embd,
-            n_mult=n_mult,
-            n_head=n_head,
-            n_layer=n_layer,
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = n_mult,
+            n_head  = n_head,
+            n_layer = n_layer,
+        )
+
+    # LLaMA v2 70B params.json
+    # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
+    @staticmethod
+    def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"];
+        n_embd  = config["dim"];
+        n_head  = config["n_heads"];
+        n_layer = config["n_layers"];
+        n_mult  = config["multiple_of"];
+
+        if n_vocab == -1:
+            n_vocab = model["tok_embeddings.weight"].shape[0]
+
+        return Params(
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = n_mult,
+            n_head  = n_head,
+            n_layer = n_layer,
         )
 
     @staticmethod
     def load(model_plus: 'ModelPlus') -> 'Params':
+        hf_config_path   = model_plus.paths[0].parent / "config.json"
         orig_config_path = model_plus.paths[0].parent / "params.json"
-        hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
 
-        if hf_transformer_config_path.exists():
-            params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+        if hf_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+        elif orig_config_path.exists():
+            params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
         else:
             params = Params.guessed(model_plus.model)
 
@@ -1036,8 +1061,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
-                        n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
index 6610397656365fdf8367bf55a45b0dec7b1d2c5b..5608ca87f2e0fa4c65a028fef20843f9b371f62d 100644 (file)
@@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gqa = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  -f FNAME, --file FNAME\n");
     fprintf(stdout, "                        prompt file to start generation.\n");
     fprintf(stdout, "  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+    fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
     fprintf(stdout, "  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, "  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -505,7 +514,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --cfg-negative-prompt PROMPT \n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -513,7 +521,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stdout, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stdout, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, "  --perplexity          compute perplexity over each ctx window of the prompt\n");
     fprintf(stdout, "  --perplexity-lines    compute perplexity over each line of the prompt\n");
     fprintf(stdout, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 
     lparams.n_ctx           = params.n_ctx;
     lparams.n_batch         = params.n_batch;
+    lparams.n_gqa           = params.n_gqa;
     lparams.n_gpu_layers    = params.n_gpu_layers;
     lparams.main_gpu        = params.main_gpu;
     lparams.tensor_split    = params.tensor_split;
index c936de6fad659f87094b81bcdb471f73de6798f0..fb8f6d65f6d8de62ba57dad570e126330fe96d51 100644 (file)
@@ -27,6 +27,7 @@ struct gpt_params {
     int32_t n_predict                       = -1;  // new tokens to predict
     int32_t n_ctx                           = 512; // context size
     int32_t n_batch                         = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_gqa                           = 1;   // grouped-query attention factor (TODO: move to hparams)
     int32_t n_keep                          = 0;   // number of tokens to keep from initial prompt
     int32_t n_chunks                        = -1;  // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers                    = 0;   // number of layers to store in VRAM
@@ -47,7 +48,7 @@ struct gpt_params {
     int32_t repeat_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   frequency_penalty = 0.00f; // 0.0 = disabled
     float   presence_penalty  = 0.00f; // 0.0 = disabled
-    int     mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
 
index 4b4cd1de4c8405bc986f9c789c940bf686e0393c..3bd8ba2621c3765a970cb9f33bfa8f3dd0ac225d 100644 (file)
@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
     }
 
     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
-                " you are on your own\n", __func__, params.n_ctx);
+        // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
index 2c5d15773c875ff334346ae6d0ac9868620fdd82..720447440183083e0877039ace008beaf03e6eff 100644 (file)
@@ -1787,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 }
 
-static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / (nchannels_y / nchannels_x);
 
     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
@@ -1807,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
         }
 
         // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
         const float xi = __half2float(x[ix]);
 
         const int row_y = col_x;
@@ -1835,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x) {
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
 
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / channel_x_divisor;
 
     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
@@ -1857,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }
 
-        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const float xi = __half2float(x[ix]);
 
         const int row_y = col_x;
@@ -2366,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
-static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+static void ggml_mul_mat_p021_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
 }
 
 static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
 
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }
 
 static void ggml_cpy_f32_f32_cuda(
@@ -3143,6 +3151,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
     const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
     const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(ne03 == ne13);
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
@@ -3154,12 +3165,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
 
     // strides for iteration over dims 3 and 2
-    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
-    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
+    const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
+    const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
     const int64_t src0_stride = ne00 * ne01 * stride_mod;
     const int64_t src1_stride = ne10 * ne11 * stride_mod;
     const int64_t dst_stride = ne0 * ne1 * stride_mod;
 
+    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+    const int64_t i03_max = flatten_rows ? 1 : ne03;
+    const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
+    const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
+    GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
 
@@ -3176,6 +3194,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 < ne12));
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
 
@@ -3212,7 +3231,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
         } else {
             row_low = 0;
-            row_high = nrows0;
+            row_high = nrows0*i02_divisor;
         }
         if (row_low == row_high) {
             continue;
@@ -3260,16 +3279,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
         }
 
-        const int64_t i03_max = flatten_rows ? 1 : ne03;
-        const int64_t i02_max = flatten_rows ? 1 : ne02;
-        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-
         for (int64_t i03 = 0; i03 < i03_max; i03++) {
             const int64_t i13 = i03 % ne13;
             for (int64_t i02 = 0; i02 < i02_max; i02++) {
                 const int64_t i12 = i02 % ne12;
 
-                const int64_t i0 = i03*ne02 + i02;
+                const int64_t i0 = i03*i02_max + i02;
 
                 // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
                 const int64_t i0_offset_low = row_low/rows_per_iter;
@@ -3303,10 +3318,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 const int64_t i11 = i13*ne12 + i12;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
-                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                char  * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
                 float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-                float * dst_ddf_i  =  dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+                float * dst_ddf_i  =  dst_ddf[id] + (i0             - i0_offset_low)*dst_stride;
 
                 // for split tensors the data pointer needs to be rounded down
                 // to the bin edge for i03, i02 bins beyond the first
@@ -3345,11 +3360,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     }
                 }
 
-                if (!src0_on_device || !src0_is_contiguous) {
+                if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
                     if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     } else {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     }
                 }
 
@@ -3503,6 +3518,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(cudaSetDevice(g_main_device));
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
@@ -3515,7 +3532,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
 }
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -3529,6 +3546,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
@@ -3547,7 +3566,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int row_stride_x = nb01 / sizeof(half);
     const int channel_stride_x = nb02 / sizeof(half);
 
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
index 0731c75ad2699a87da1d4369b4fcc3a5887d24bf..5a8453bec3ada4c00bbcc1ed151d9acfb7f54be5 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -67,6 +67,7 @@ enum e_model {
     MODEL_13B,
     MODEL_30B,
     MODEL_65B,
+    MODEL_70B,
 };
 
 static const size_t kB = 1024;
@@ -109,6 +110,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
         { MODEL_13B,  ((size_t) n_ctx / 12ull + 120ull) * MB },
         { MODEL_30B,  ((size_t) n_ctx /  9ull + 160ull) * MB },
         { MODEL_65B,  ((size_t) n_ctx /  6ull + 256ull) * MB }, // guess
+        { MODEL_70B,  ((size_t) n_ctx /  7ull + 164ull) * MB },
     };
     return k_sizes;
 }
@@ -121,6 +123,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
         { MODEL_13B, 192ull * MB },
         { MODEL_30B, 256ull * MB },
         { MODEL_65B, 384ull * MB }, // guess
+        { MODEL_70B, 304ull * MB },
     };
     return k_sizes;
 }
@@ -134,6 +137,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
         { MODEL_13B, 12ull * MB },
         { MODEL_30B, 16ull * MB },
         { MODEL_65B, 24ull * MB }, // guess
+        { MODEL_70B, 24ull * MB },
     };
     return k_sizes;
 }
@@ -148,6 +152,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
         { MODEL_13B,  640ull * kB },
         { MODEL_30B,  768ull * kB },
         { MODEL_65B, 1536ull * kB },
+        { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced)
     };
     return k_sizes;
 }
@@ -162,19 +167,25 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
         { MODEL_13B, 160ull },
         { MODEL_30B, 208ull },
         { MODEL_65B, 416ull },
+        { MODEL_70B, 416ull }, // TODO (likely can be reduced)
     };
     return k_sizes;
 }
 
 // default hparams (LLaMA 7B)
 struct llama_hparams {
-    uint32_t n_vocab = 32000;
-    uint32_t n_ctx   = 512;   // this is provided as user input?
-    uint32_t n_embd  = 4096;
-    uint32_t n_mult  = 256;
-    uint32_t n_head  = 32;
-    uint32_t n_layer = 32;
-    uint32_t n_rot   = 64;
+    uint32_t n_vocab   = 32000;
+    uint32_t n_ctx     = 512;   // this is provided as user input?
+    uint32_t n_embd    = 4096;
+    uint32_t n_mult    = 256;
+    uint32_t n_head    = 32;
+    uint32_t n_head_kv = 32;
+    uint32_t n_layer   = 32;
+    uint32_t n_rot     = 64;
+
+    // LLaMAv2
+    // TODO: load from model data hparams
+    float f_ffn_mult = 1.0f;
 
     float rope_freq_base  = 10000.0f;
     float rope_freq_scale = 1.0f;
@@ -182,12 +193,24 @@ struct llama_hparams {
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
     bool operator!=(const llama_hparams & other) const {
-        return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams)));
+        return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
+    }
+
+    uint32_t n_gqa() const {
+        return n_head/n_head_kv;
+    }
+
+    uint32_t n_embd_head() const {
+        return n_embd/n_head;
+    }
+
+    uint32_t n_embd_gqa() const {
+        return n_embd/n_gqa();
     }
 
     size_t kv_size() const {
         size_t result = 2ull;
-        result *= (size_t) n_embd;
+        result *= (size_t) n_embd_gqa();
         result *= (size_t) n_ctx;
         result *= (size_t) n_layer;
         result *= sizeof(ggml_fp16_t);
@@ -493,12 +516,16 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
-        hparams.n_embd = file.read_u32();
-        hparams.n_mult = file.read_u32();
-        hparams.n_head = file.read_u32();
+        hparams.n_embd  = file.read_u32();
+        hparams.n_mult  = file.read_u32();
+        hparams.n_head  = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot = file.read_u32();
-        hparams.ftype = (enum llama_ftype) file.read_u32();
+        hparams.n_rot   = file.read_u32();
+        hparams.ftype   = (enum llama_ftype) file.read_u32();
+
+        // LLaMAv2
+        // TODO: read from header
+        hparams.n_head_kv = hparams.n_head;
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
@@ -797,7 +824,7 @@ static bool kv_cache_init(
                          ggml_type   wtype,
                                int   n_ctx,
                                int   n_gpu_layers) {
-    const int n_embd  = hparams.n_embd;
+    const int n_embd  = hparams.n_embd_gqa();
     const int n_layer = hparams.n_layer;
 
     const int64_t n_mem      = n_layer*n_ctx;
@@ -841,6 +868,7 @@ struct llama_context_params llama_context_default_params() {
         /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
+        /*.n_gqa                       =*/ 1,
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
@@ -960,6 +988,7 @@ static const char *llama_model_type_name(e_model type) {
         case MODEL_13B: return "13B";
         case MODEL_30B: return "30B";
         case MODEL_65B: return "65B";
+        case MODEL_70B: return "70B";
         default: LLAMA_ASSERT(false);
     }
 }
@@ -970,6 +999,7 @@ static void llama_model_load_internal(
         llama_vocab & vocab,
         int n_ctx,
         int n_batch,
+        int n_gqa,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -991,6 +1021,7 @@ static void llama_model_load_internal(
     model.hparams = ml->file_loader->hparams;
     model.n_gpu_layers = n_gpu_layers;
     llama_file_version file_version = ml->file_loader->file_version;
+
     auto & hparams = model.hparams;
 
     {
@@ -1010,11 +1041,25 @@ static void llama_model_load_internal(
 
         hparams.n_ctx = n_ctx;
 
+        // LLaMAv2
+        // TODO: temporary until GGUF
+        LLAMA_ASSERT(hparams.n_head % n_gqa == 0);
+        hparams.n_head_kv = hparams.n_head / n_gqa;
+        if (model.type == e_model::MODEL_65B && n_gqa == 8) {
+            fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
+            model.type = e_model::MODEL_70B;
+            hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
+        }
+
         hparams.rope_freq_base  = rope_freq_base;
         hparams.rope_freq_scale = rope_freq_scale;
     }
 
-    const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+    // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
+    const uint32_t n_ff_raw  = 2*(4*hparams.n_embd)/3;
+    const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw;
+    const uint32_t n_ff      = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+    //const uint32_t n_ff = 28672;
 
     {
         fprintf(stderr, "%s: format     = %s\n",   __func__, llama_file_version_name(file_version));
@@ -1023,12 +1068,14 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_embd     = %u\n",   __func__, hparams.n_embd);
         fprintf(stderr, "%s: n_mult     = %u\n",   __func__, hparams.n_mult);
         fprintf(stderr, "%s: n_head     = %u\n",   __func__, hparams.n_head);
+        fprintf(stderr, "%s: n_head_kv  = %u\n",   __func__, hparams.n_head_kv);
         fprintf(stderr, "%s: n_layer    = %u\n",   __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_rot);
+        fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+        fprintf(stderr, "%s: n_gqa      = %u\n",   __func__, hparams.n_gqa());
+        fprintf(stderr, "%s: n_ff       = %u\n",   __func__, n_ff);
         fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
         fprintf(stderr, "%s: freq_scale = %g\n",   __func__, hparams.rope_freq_scale);
         fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
-        fprintf(stderr, "%s: n_ff       = %u\n",   __func__, n_ff);
         fprintf(stderr, "%s: model size = %s\n",   __func__, llama_model_type_name(model.type));
     }
 
@@ -1098,9 +1145,10 @@ static void llama_model_load_internal(
     size_t vram_weights = 0;
     size_t vram_scratch = 0;
     {
-        const uint32_t n_embd  = hparams.n_embd;
-        const uint32_t n_layer = hparams.n_layer;
-        const uint32_t n_vocab = hparams.n_vocab;
+        const uint32_t n_embd     = hparams.n_embd;
+        const uint32_t n_embd_gqa = hparams.n_embd_gqa();
+        const uint32_t n_layer    = hparams.n_layer;
+        const uint32_t n_vocab    = hparams.n_vocab;
 
         ml->ggml_ctx = ctx;
 
@@ -1148,16 +1196,16 @@ static void llama_model_load_internal(
 
             layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
 
-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split);
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split);
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split);
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split);
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd},     backend_split);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd},     backend_split);
 
             layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
 
-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff},   backend_split);
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd}, backend_split);
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff},   backend_split);
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff}, backend_split);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff, n_embd}, backend_split);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff}, backend_split);
 
             if (backend == GGML_BACKEND_GPU) {
                 vram_weights +=
@@ -1281,6 +1329,7 @@ static bool llama_model_load(
         llama_vocab & vocab,
         int n_ctx,
         int n_batch,
+        int n_gqa,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1294,7 +1343,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1338,17 +1387,22 @@ static bool llama_eval_internal(
 
     LLAMA_ASSERT(!!kv_self.ctx);
 
-    const int n_embd       = hparams.n_embd;
-    const int n_layer      = hparams.n_layer;
-    const int n_ctx        = hparams.n_ctx;
-    const int n_head       = hparams.n_head;
-    const int n_vocab      = hparams.n_vocab;
-    const int n_rot        = hparams.n_embd/hparams.n_head;
-    const int n_gpu_layers = model.n_gpu_layers;
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv;
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_vocab     = hparams.n_vocab;
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    LLAMA_ASSERT(n_embd_head == hparams.n_rot);
 
     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
 
+    const int n_gpu_layers = model.n_gpu_layers;
+
     auto & mem_per_token = lctx.mem_per_token;
     auto & buf_compute   = lctx.buf_compute;
 
@@ -1446,11 +1500,11 @@ static bool llama_eval_internal(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),    n_past, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
@@ -1462,17 +1516,17 @@ static bool llama_eval_internal(
                 offload_func_v(tmpv);
                 ggml_set_name(tmpv, "tmpv");
 
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
                 offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
                 offload_func_v(v);
                 ggml_set_name(v, "v");
 
@@ -1491,8 +1545,8 @@ static bool llama_eval_internal(
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
                         ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
-                            n_embd/n_head, n_head, n_past + N),
+                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa),
+                            n_embd_head, n_head_kv, n_past + N),
                         0, 2, 1, 3);
             offload_func_kq(K);
             ggml_set_name(K, "K");
@@ -1502,9 +1556,9 @@ static bool llama_eval_internal(
             offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");
 
-            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // KQ_scaled = KQ / sqrt(n_embd_head)
             struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
-            ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
+            ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
             // KQ_scaled shape [n_past + N, N, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
@@ -1524,10 +1578,10 @@ static bool llama_eval_internal(
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_embd/n_head, n_head,
+                        n_past + N, n_embd_head, n_head_kv,
                         n_ctx*ggml_element_size(kv_self.v),
-                        n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
-                        il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+                        n_ctx*ggml_element_size(kv_self.v)*n_embd_head,
+                        n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il);
             offload_func_v(V);
             ggml_set_name(V, "V");
 
@@ -1539,7 +1593,7 @@ static bool llama_eval_internal(
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
             // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif
 
@@ -2693,7 +2747,7 @@ struct llama_model * llama_load_model_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
+    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
                 params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
diff --git a/llama.h b/llama.h
index bbf28e68684cf60dc65f08f3eba1ff26565abcee..1089909a627f8b026570f6a70eac643c3cead902 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -83,11 +83,12 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
    struct llama_context_params {
-        uint32_t seed;                         // RNG seed, -1 for random
-        int32_t  n_ctx;                        // text context
-        int32_t  n_batch;                      // prompt processing batch size
-        int32_t  n_gpu_layers;                 // number of layers to store in VRAM
-        int32_t  main_gpu;                     // the GPU that is used for scratch and small tensors
+        uint32_t seed;         // RNG seed, -1 for random
+        int32_t  n_ctx;        // text context
+        int32_t  n_batch;      // prompt processing batch size
+        int32_t  n_gqa;        // grouped-query attention (TEMP - will be moved to model hparams)
+        int32_t  n_gpu_layers; // number of layers to store in VRAM
+        int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
 
         const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)