llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (#4490)
author    Ebey Abraham <redacted>
Mon, 18 Dec 2023 17:27:47 +0000 (17:27 +0000)
committer GitHub <redacted>
Mon, 18 Dec 2023 17:27:47 +0000 (19:27 +0200)
* phi2 implementation

* fix breaking change

* phi-2 : various fixes

* phi-2 : use layer norm eps

* py : whitespaces

* llama : fix meta KV override bug

* convert : phi don't add BOS token

* convert : revert "added_tokens_decoder" change

* phi-2 : scale Q instead of KQ for better precision

* ggml : fix NeoX rope to rotate just first n_dims

* cuda : less diff in the rope_neox kernel

* ggml : add ggml_mul_mat_set_prec

ggml-ci

* Update ggml-cuda.cu

Co-authored-by: slaren <redacted>
* Update ggml-cuda.cu

Co-authored-by: slaren <redacted>
* cuda : ggml_cuda_op_mul_mat_cublas support F32 precision

* cuda : remove obsolete comment

---------

Co-authored-by: Ebey Abraham <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
convert-hf-to-gguf.py
ggml-cuda.cu
ggml-metal.metal
ggml.c
ggml.h
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp
tests/test-backend-ops.cpp

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index e46a7813a78e98f42b0611179ccd3df312930896..e71a96c483313e205d9c789f20331ef3876b8858 100755 (executable)
@@ -182,6 +182,8 @@ class Model:
             return QwenModel
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
+        if model_architecture == "PhiForCausalLM":
+            return Phi2Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -221,6 +223,8 @@ class Model:
             return gguf.MODEL_ARCH.QWEN
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
+        if arch == "PhiForCausalLM":
+            return gguf.MODEL_ARCH.PHI2
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -980,6 +984,24 @@ class QwenModel(Model):
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
+
+class Phi2Model(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layer"]
+
+        self.gguf_writer.add_name("Phi2")
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_add_bos_token(False)
+
+
 ###### CONVERSION LOGIC ######
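For reference, Phi2Model derives the GGUF feed-forward length as 4 * n_embd instead of reading it from the HF config (for the published phi-2 checkpoint n_embd is 2560, so n_ff = 10240), and the rotary dimension comes from rotary_dim, which is smaller than the head size (32 vs 80), so only part of each head gets roped - this is what the NeoX rope changes below handle.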
 
 
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0a63c1ecf3bfb1032db5148ea0abce655c528362..d0f3d803459364dd88c36bdaabcf50418759bb84 100644 (file)
@@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
     const int ib = col / n_dims;
     const int ic = col % n_dims;
 
-    const int i = row*ncols + ib*n_dims + ic/2;
+    if (ib > 0) {
+        const int i = row*ncols + ib*n_dims + ic;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
     float cur_rot = inv_ndims * ic - ib;
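The fix above makes the NeoX kernel rotate only the first n_dims columns of each row and copy the remaining columns through unchanged, instead of treating every column as part of a rotated block. A minimal CPU sketch of the intended per-row behaviour (a hypothetical rope_neox_row_ref helper, illustrative only, using the textbook theta = pos * freq_base^(-ic/n_dims) formula rather than the exact ggml/YaRN path):

    #include <math.h>

    // rotate pairs (i, i + n_dims/2) for the first n_dims elements, pass the rest through
    static void rope_neox_row_ref(float * dst, const float * x,
                                  int ne0, int n_dims, float pos, float freq_base) {
        for (int ic = 0; ic < ne0; ic += 2) {
            if (ic < n_dims) {
                const int   i0    = ic/2;
                const float theta = pos * powf(freq_base, -(float) ic / n_dims);
                const float x0 = x[i0];
                const float x1 = x[i0 + n_dims/2];
                dst[i0]            = x0*cosf(theta) - x1*sinf(theta);
                dst[i0 + n_dims/2] = x0*sinf(theta) + x1*cosf(theta);
            } else {
                dst[ic + 0] = x[ic + 0];  // the un-rotated tail is copied as-is
                dst[ic + 1] = x[ic + 1];
            }
        }
    }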
@@ -7057,6 +7066,7 @@ inline void ggml_cuda_op_upscale(
 
     (void) src1;
     (void) dst;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_pad(
@@ -7073,6 +7083,7 @@ inline void ggml_cuda_op_pad(
 
     (void) src1;
     (void) dst;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_rms_norm(
@@ -7376,7 +7387,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = g_compute_capabilities[id];
 
-    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
         size_t src0_as = 0;
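Note that with this extra condition, a matmul whose op_params[0] has been set to GGML_PREC_F32 (via ggml_mul_mat_set_prec, added below) no longer takes the FP16 fast path here and instead falls through to the F32 cuBLAS path, trading some speed for the precision phi-2 needs.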
@@ -8300,27 +8311,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3) {
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t  nb02, size_t  nb03,
+        size_t  nb12, size_t  nb13,
+        size_t  nbd2, size_t  nbd3,
+        int64_t r2,   int64_t r3) {
+    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
+    int64_t i03 = i13 / r3;
+    int64_t i02 = i12 / r2;
 
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2   + i13*nbd3;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8376,7 +8387,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    half * dst_f16 = nullptr;
+    char * dst_t   = nullptr;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    // dst strides
+    size_t nbd2 = dst->nb[2];
+    size_t nbd3 = dst->nb[3];
+
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
+    const float alpha_f32 = 1.0f;
+    const float beta_f32  = 0.0f;
+
+    const void * alpha = &alpha_f16;
+    const void * beta  = &beta_f16;
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+        dst_t   = (char *) dst_f16;
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
+    } else {
+        dst_t = (char *) dst_ddf;
+
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -8385,9 +8430,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
 #if 0
     // use cublasGemmEx
     {
@@ -8397,12 +8439,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 int i02 = i12 / r2;
 
                 CUBLAS_CHECK(
-                        cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                        cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
-                            &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
-                                        (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
-                            &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
-                            CUBLAS_COMPUTE_16F,
+                            alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
+                                   (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
+                            beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
+                            cu_compute_type,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
             }
         }
@@ -8414,11 +8456,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                alpha, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                       (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                beta,  (      char *)       dst_t, cu_data_type, ne01,                dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
@@ -8435,24 +8477,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_f16,
+                src0_as_f16, src1_as_f16, dst_t,
                 ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
                 nb12, nb13,
-                dst->nb[2], dst->nb[3],
+                nbd2, nbd3,
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+                alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
+                       (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
+                beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
@@ -8464,11 +8506,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     }
 #endif
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+        ggml_cuda_pool_free(dst_f16, dst_as);
+    }
 
     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
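Likewise, in the batched path above, requesting GGML_PREC_F32 switches the compute/data types to CUBLAS_COMPUTE_32F / CUDA_R_32F and writes directly into the F32 dst buffer, so the temporary half buffer and the final fp16-to-fp32 conversion are only used on the default-precision path.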
diff --git a/ggml-metal.metal b/ggml-metal.metal
index fe0ada445a2d45c4f295021565a8e9247c9b2dfc..d5b54e112ea37ebcfe61d1ed38ef659b6f240f83 100644 (file)
@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+            if (ic < n_dims) {
+                const int64_t ib = 0;
 
                 // simplified from `(ib * n_dims + ic) * inv_ndims`
                 const float cur_rot = inv_ndims*ic - ib;
@@ -1722,6 +1723,14 @@ kernel void kernel_rope(
 
                 dst_data[0]        = x0*cos_theta - x1*sin_theta;
                 dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            } else {
+                const int64_t i0 = ic;
+
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                dst_data[0] = src[0];
+                dst_data[1] = src[1];
             }
         }
     }
diff --git a/ggml.c b/ggml.c
index ad546a7314a146680b592df39095847c7db94973..6da65bd9286307ba822a384998b3b2404a3b7ecd 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+void ggml_mul_mat_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
 // ggml_mul_mat_id
 
 struct ggml_tensor * ggml_mul_mat_id(
@@ -9168,6 +9176,8 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
+    GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9237,6 +9247,8 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
+    GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -11562,10 +11574,13 @@ static void ggml_compute_forward_rope_f32(
                     }
                 } else {
                     // TODO: this might be wrong for ne0 != n_dims - need double check
-                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    //       it seems we have to rope just the first n_dims elements and do nothing with the rest
+                    // ref:  https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
                     theta_base *= freq_scale;
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                    for (int64_t ic = 0; ic < ne0; ic += 2) {
+                        if (ic < n_dims) {
+                            const int64_t ib = 0;
+
                             // simplified from `(ib * n_dims + ic) * inv_ndims`
                             float cur_rot = inv_ndims * ic - ib;
 
@@ -11588,6 +11603,14 @@ static void ggml_compute_forward_rope_f32(
 
                             dst_data[0]        = x0*cos_theta - x1*sin_theta;
                             dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        } else {
+                            const int64_t i0 = ic;
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            dst_data[0] = src[0];
+                            dst_data[1] = src[1];
                         }
                     }
                 }
@@ -11715,10 +11738,13 @@ static void ggml_compute_forward_rope_f16(
                     }
                 } else {
                     // TODO: this might be wrong for ne0 != n_dims - need double check
-                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    //       it seems we have to rope just the first n_dims elements and do nothing with the rest
+                    // ref:  https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
                     theta_base *= freq_scale;
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                    for (int64_t ic = 0; ic < ne0; ic += 2) {
+                        if (ic < n_dims) {
+                            const int64_t ib = 0;
+
                             // simplified from `(ib * n_dims + ic) * inv_ndims`
                             float cur_rot = inv_ndims * ic - ib;
 
@@ -11741,6 +11767,14 @@ static void ggml_compute_forward_rope_f16(
 
                             dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                             dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        } else {
+                            const int64_t i0 = ic;
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            dst_data[0] = src[0];
+                            dst_data[1] = src[1];
                         }
                     }
                 }
diff --git a/ggml.h b/ggml.h
index 68f7833b623436695b94ef978118772025d22ff6..f1003984fa0eaed1f0807417cb233e36a7ab411e 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -343,6 +343,12 @@ extern "C" {
         GGML_TYPE_COUNT,
     };
 
+    // precision
+    enum ggml_prec {
+        GGML_PREC_DEFAULT,
+        GGML_PREC_F32,
+    };
+
     enum ggml_backend_type {
         GGML_BACKEND_CPU = 0,
         GGML_BACKEND_GPU = 10,
@@ -1057,6 +1063,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // change the precision of a matrix multiplication
+    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+    GGML_API void ggml_mul_mat_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
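A hedged usage sketch (the build_kq_f32 helper is illustrative, not from this diff), mirroring how llama.cpp applies this to the KQ product for phi-2 further down: the precision request is attached to the product tensor right after it is created, before the graph is evaluated.

    #include "ggml.h"

    // sketch: request F32 accumulation for one matmul (here, the KQ product)
    static struct ggml_tensor * build_kq_f32(struct ggml_context * ctx,
                                             struct ggml_tensor  * k,
                                             struct ggml_tensor  * q) {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // stored in kq->op_params[0]
        return kq;
    }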
     // indirect matrix multiplication
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 12133882be2c4065f20091be9fe0ea8c3ae73adc..390dca049ebee5bbb240cf91be75e2c2d2b7118a 100644 (file)
@@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum):
     BLOOM     = auto()
     STABLELM  = auto()
     QWEN      = auto()
+    PHI2      = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -140,6 +141,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BLOOM:          "bloom",
     MODEL_ARCH.STABLELM:       "stablelm",
     MODEL_ARCH.QWEN:           "qwen",
+    MODEL_ARCH.PHI2:           "phi2",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -350,6 +352,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.GPT2: [
         # TODO
     ],
+    MODEL_ARCH.PHI2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ]
     # TODO
 }
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 0115ea1c605b1f796ba870fd00522a0bf7414a91..6fcbdbc1c0d4cd43bb0c1abce8ac2f0196107c56 100644 (file)
@@ -17,6 +17,7 @@ class TensorNameMap:
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert
             "language_model.embedding.word_embeddings",  # persimmon
+            "transformer.embd.wte",                      # phi2
         ),
 
         # Token type embeddings
@@ -41,6 +42,7 @@ class TensorNameMap:
             "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen
             "output",                    # llama-pth bloom
             "word_embeddings_for_head",  # persimmon
+            "lm_head.linear",            # phi2
         ),
 
         # Output norm
@@ -53,6 +55,7 @@ class TensorNameMap:
             "transformer.norm_f",                      # mpt
             "ln_f",                                    # refact bloom qwen
             "language_model.encoder.final_layernorm",  # persimmon
+            "lm_head.ln",                              # phi2
         ),
 
         # Rope frequencies
@@ -75,6 +78,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.LayerNorm",       # bert
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
+            "transformer.h.{bid}.ln",                               # phi2
         ),
 
         # Attention norm 2
@@ -90,6 +94,7 @@ class TensorNameMap:
             "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
             "h.{bid}.self_attention.query_key_value",                              # bloom
             "language_model.encoder.layers.{bid}.self_attention.query_key_value",  # persimmon
+            "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
         ),
 
         # Attention query
@@ -128,6 +133,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.dense",                # bert
             "transformer.h.{bid}.attn.out_proj",                         # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
+            "transformer.h.{bid}.mixer.out_proj",                        # phi2
         ),
 
         # Rotary embeddings
@@ -167,6 +173,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
             "transformer.h.{bid}.mlp.w1",                             # qwen
+            "transformer.h.{bid}.mlp.fc1",                            # phi2
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -198,6 +205,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.dense",                       # bert
             "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
+            "transformer.h.{bid}.mlp.fc2",                            # phi2
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
diff --git a/llama.cpp b/llama.cpp
index 99facbf77a1e2f8388447b61ca94eb6051e29141..edd2910b3ad29ec28d5744709fe50f52f229823c 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -195,6 +195,7 @@ enum llm_arch {
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
+    LLM_ARCH_PHI2,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -212,6 +213,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_BLOOM,           "bloom"     },
     { LLM_ARCH_STABLELM,        "stablelm"  },
     { LLM_ARCH_QWEN,            "qwen"      },
+    { LLM_ARCH_PHI2,            "phi2"      },
 };
 
 enum llm_kv {
@@ -550,6 +552,19 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PHI2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
 
     {
         LLM_ARCH_UNKNOWN,
@@ -1420,6 +1435,7 @@ struct llama_model {
     struct ggml_tensor * output_norm;
     struct ggml_tensor * output_norm_b;
     struct ggml_tensor * output;
+    struct ggml_tensor * output_b;
 
     std::vector<llama_layer> layers;
 
@@ -2635,6 +2651,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_PHI2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_3B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
 
         default: (void)0;
     }
@@ -2987,7 +3012,7 @@ static void llm_load_tensors(
 
     (void) main_gpu;
 
-    enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
+    enum ggml_backend_type llama_backend_offload       = GGML_BACKEND_CPU;
     enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
 
 #ifdef GGML_USE_CUBLAS
@@ -3630,7 +3655,73 @@ static void llm_load_tensors(
                         }
                     }
                 } break;
+            case LLM_ARCH_PHI2:
+                {
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend_type backend_norm;
+                        ggml_backend_type backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            backend_norm   = llama_backend_offload;
+                            backend_output = llama_backend_offload;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+                        model.output_b      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab},         backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                            vram_weights += ggml_nbytes(model.output_norm_b);
+                            vram_weights += ggml_nbytes(model.output);
+                            vram_weights += ggml_nbytes(model.output_b);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
 
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+                        layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, backend);
+
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+                        layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa},         backend);
+
+                        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+                        layer.bo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd},         backend);
+
+                        layer.ffn_down   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
+                        layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd},       backend);
+
+                        layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+                        layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff},         backend);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+                                ggml_nbytes(layer.wqkv)      + ggml_nbytes(layer.bqkv)        +
+                                ggml_nbytes(layer.wo)        + ggml_nbytes(layer.bo)          +
+                                ggml_nbytes(layer.ffn_up)    + ggml_nbytes(layer.ffn_up_b)    +
+                                ggml_nbytes(layer.ffn_down)  + ggml_nbytes(layer.ffn_down_b);
+                        }
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -3991,6 +4082,7 @@ static struct ggml_tensor * llm_build_ffn(
 // if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
+          const llama_model & model,
         const llama_hparams & hparams,
        const llama_kv_cache & kv,
          struct ggml_tensor * wo,
@@ -4002,6 +4094,7 @@ static struct ggml_tensor * llm_build_kqv(
                     int32_t   n_tokens,
                     int32_t   n_kv,
                     float     max_alibi_bias,
+                    float     scale,
          const llm_build_cb & cb,
                     int       il) {
     const int64_t n_embd      = hparams.n_embd;
@@ -4024,6 +4117,12 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);
 
+    if (model.arch == LLM_ARCH_PHI2) {
+        // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+        // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+    }
+
     if (max_alibi_bias > 0.0f) {
         // temporary branch until we figure out how to handle ggml_alibi through ggml_add
         kq = ggml_scale(ctx, kq, kq_scale);
@@ -4043,7 +4142,7 @@ static struct ggml_tensor * llm_build_kqv(
         kq = ggml_soft_max(ctx, kq);
         cb(kq, "kq_soft_max", il);
     } else {
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
         cb(kq, "kq_soft_max_ext", il);
     }
 
@@ -4250,9 +4349,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4433,9 +4532,9 @@ struct llm_build_context {
                 // apply ALiBi for 13B model
                 const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4557,9 +4656,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4657,9 +4756,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4866,9 +4965,9 @@ struct llm_build_context {
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
                 // TODO: not tested, could be broken
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4957,9 +5056,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5054,9 +5153,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5148,9 +5247,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5261,9 +5360,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5320,15 +5419,15 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_scale
-        struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
         cb(KQ_scale, "KQ_scale", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
@@ -5378,9 +5477,9 @@ struct llm_build_context {
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5422,6 +5521,122 @@ struct llm_build_context {
 
         ggml_build_forward_expand(gf, cur);
 
+        return gf;
+    }
+    struct ggml_cgraph * build_phi2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * attn_norm_output;
+        struct ggml_tensor * ffn_output;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
+
+        // Q_scale
+        struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(Q_scale, "Q_scale", -1);
+
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(attn_norm_output, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, Q_scale);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            // FF
+            {
+                ffn_output = llm_build_ffn(ctx0, attn_norm_output,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(ffn_output, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_output);
+            cb(cur, "l_out", il);
+
+            cur = ggml_add(ctx0, cur, inpL);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output_no_bias", -1);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
         return gf;
     }
 };
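A rough per-layer dataflow sketch of build_phi2 (illustrative only, condensed from the graph code above):

    // h  = attn_norm(x)               (LayerNorm with bias, eps from the GGUF KV)
    // a  = attn(h)                    (QKV proj + bias, partial NeoX rope over n_rot dims,
    //                                  Q pre-scaled by 1/sqrt(n_embd_head))
    // f  = ffn(h)                     (fc1 -> GELU -> fc2, computed from the same h)
    // x' = a + f + x                  (parallel residual, no second per-layer norm)
    //
    // final: norm -> output matmul ("result_output_no_bias") -> + output bias ("result_output")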
@@ -5437,7 +5652,7 @@ enum llm_offload_func_e {
     OFFLOAD_FUNC_FRC, // force offload
     OFFLOAD_FUNC_KQV,
     OFFLOAD_FUNC_NR,
-    OFFLOAD_FUNC_EMB,
+    OFFLOAD_FUNC_EMB, // embeddings
     OFFLOAD_FUNC_OUT,
 };
 
@@ -5522,6 +5737,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "pos_embd",                   OFFLOAD_FUNC_NR  },
 
     { "inp_pos",                    OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
+    { "Q_scale",                    OFFLOAD_FUNC_FRC },
     { "KQ_scale",                   OFFLOAD_FUNC_FRC },
     { "KQ_mask",                    OFFLOAD_FUNC_FRC },
     { "K_shift",                    OFFLOAD_FUNC_FRC },
@@ -5606,6 +5822,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "l_out",                      OFFLOAD_FUNC     },
 
     { "result_norm",                OFFLOAD_FUNC_EMB },
+    { "result_output_no_bias",      OFFLOAD_FUNC_EMB },
     { "result_output",              OFFLOAD_FUNC_OUT },
 };
 
@@ -5623,6 +5840,7 @@ static struct ggml_cgraph * llama_build_graph(
     bool alloc_inp_tokens   = false;
     bool alloc_inp_embd     = false;
     bool alloc_inp_pos      = false;
+    bool alloc_inp_Q_scale  = false;
     bool alloc_inp_KQ_scale = false;
     bool alloc_inp_KQ_mask  = false;
     bool alloc_inp_K_shift  = false;
@@ -5690,7 +5908,7 @@ static struct ggml_cgraph * llama_build_graph(
             alloc_inp_pos = true;
         }
 
-        if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
+        if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) {
             ggml_allocr_alloc(lctx.alloc, cur);
 
             if (!ggml_allocr_is_measure(lctx.alloc)) {
@@ -5698,6 +5916,23 @@ static struct ggml_cgraph * llama_build_graph(
                 ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
             }
 
+            alloc_inp_Q_scale = true;
+        }
+
+        if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
+            ggml_allocr_alloc(lctx.alloc, cur);
+
+            if (!ggml_allocr_is_measure(lctx.alloc)) {
+                const int64_t n_embd_head = model.hparams.n_embd_head();
+                if (model.arch == LLM_ARCH_PHI2) {
+                    // with phi2, we scale the Q to avoid precision issues
+                    // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
+                    ggml_set_f32(cur, 1.0f);
+                } else {
+                    ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
+                }
+            }
+
             alloc_inp_KQ_scale = true;
         }
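Q_scale is always filled with 1/sqrt(n_embd_head); KQ_scale is set to 1.0f only for phi-2, since the scaling has already been folded into Q in build_phi2, and remains 1/sqrt(n_embd_head) for every other architecture.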
 
@@ -5922,6 +6157,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_qwen();
             } break;
+        case LLM_ARCH_PHI2:
+            {
+                result = llm.build_phi2();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -6055,12 +6294,16 @@ static int llama_decode_internal(
 
     ggml_allocr_alloc_graph(lctx.alloc, gf);
 
-    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
-    GGML_ASSERT(strcmp(res->name,        "result_output") == 0);
-    GGML_ASSERT(strcmp(embeddings->name, "result_norm")   == 0);
+    // the output is always the last tensor in the graph
+    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
 
+    // the embeddings could be the second to last tensor, or the third to last tensor
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+    if (strcmp(embeddings->name, "result_norm") != 0) {
+        embeddings = gf->nodes[gf->n_nodes - 3];
+        GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+    }
 
 #ifdef GGML_USE_CUBLAS
     for (int i = 0; i < gf->n_leafs; i++) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index df2c3fb6e031fc745c785d8963785a7c8e16939c..f04b9438a6194d2b934bf380f9ee188d102d39bf 100644 (file)
@@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
         test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
         test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512)); // neox (stablelm)
+        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512)); // neox (phi-2)
     }
 
     test_cases.emplace_back(new test_alibi());