const int ib = col / n_dims;
const int ic = col % n_dims;
- const int i = row*ncols + ib*n_dims + ic/2;
+ if (ib > 0) {
+ const int i = row*ncols + ib*n_dims + ic;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
(void) src1;
(void) dst;
+ (void) src1_dd;
}
inline void ggml_cuda_op_pad(
(void) src1;
(void) dst;
+ (void) src1_dd;
}
inline void ggml_cuda_op_rms_norm(
const int compute_capability = g_compute_capabilities[id];
- if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+ if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
}
static __global__ void k_compute_batched_ptrs(
- const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+ const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
- int ne12, int ne13,
- int ne23,
- int nb02, int nb03,
- int nb12, int nb13,
- int nb2, int nb3,
- int r2, int r3) {
- int i13 = blockIdx.x * blockDim.x + threadIdx.x;
- int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+ int64_t ne12, int64_t ne13,
+ int64_t ne23,
+ size_t nb02, size_t nb03,
+ size_t nb12, size_t nb13,
+ size_t nbd2, size_t nbd3,
+ int64_t r2, int64_t r3) {
+ int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+ int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
if (i13 >= ne13 || i12 >= ne12) {
return;
}
- int i03 = i13 / r3;
- int i02 = i12 / r2;
+ int64_t i03 = i13 / r3;
+ int64_t i02 = i12 / r2;
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
size_t dst_as = 0;
- half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+ half * dst_f16 = nullptr;
+ char * dst_t = nullptr;
+
+ cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+ cudaDataType_t cu_data_type = CUDA_R_16F;
+
+ // dst strides
+ size_t nbd2 = dst->nb[2];
+ size_t nbd3 = dst->nb[3];
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ const void * alpha = &alpha_f16;
+ const void * beta = &beta_f16;
+
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+ dst_t = (char *) dst_f16;
+
+ nbd2 /= sizeof(float) / sizeof(half);
+ nbd3 /= sizeof(float) / sizeof(half);
+ } else {
+ dst_t = (char *) dst_ddf;
+
+ cu_compute_type = CUBLAS_COMPUTE_32F;
+ cu_data_type = CUDA_R_32F;
+
+ alpha = &alpha_f32;
+ beta = &beta_f32;
+ }
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
#if 0
// use cublasGemmEx
{
int i02 = i12 / r2;
CUBLAS_CHECK(
- cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
- (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
- &beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
- CUBLAS_COMPUTE_16F,
+ alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
+ (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+ beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
+ cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
- (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
- &beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
+ alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
+ (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+ beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
- CUBLAS_COMPUTE_16F,
+ cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
- src0_as_f16, src1_as_f16, dst_f16,
+ src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
nb02, nb03,
nb12, nb13,
- dst->nb[2], dst->nb[3],
+ nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
- (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
- &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+ alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+ (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+ beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
ne23,
- CUBLAS_COMPUTE_16F,
+ cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
if (ptrs_src_s != 0) {
}
#endif
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+ ggml_cuda_pool_free(dst_f16, dst_as);
+ }
ggml_cuda_pool_free(src1_as_f16, src1_as);
- ggml_cuda_pool_free(dst_f16, dst_as);
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
+ LLM_ARCH_PHI2,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
+ { LLM_ARCH_PHI2, "phi2" },
};
enum llm_kv {
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_PHI2,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
struct ggml_tensor * output_norm;
struct ggml_tensor * output_norm_b;
struct ggml_tensor * output;
+ struct ggml_tensor * output_b;
std::vector<llama_layer> layers;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_PHI2:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_3B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
(void) main_gpu;
- enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
+ enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
#ifdef GGML_USE_CUBLAS
}
}
} break;
+ case LLM_ARCH_PHI2:
+ {
+ model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+ // output
+ {
+ ggml_backend_type backend_norm;
+ ggml_backend_type backend_output;
+
+ if (n_gpu_layers > int(n_layer)) {
+ backend_norm = llama_backend_offload;
+ backend_output = llama_backend_offload;
+ } else {
+ backend_norm = GGML_BACKEND_CPU;
+ backend_output = GGML_BACKEND_CPU;
+ }
+
+ model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
+ model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
+ model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
+ model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output);
+
+ if (backend_norm == GGML_BACKEND_GPU) {
+ vram_weights += ggml_nbytes(model.output_norm);
+ vram_weights += ggml_nbytes(model.output_norm_b);
+ vram_weights += ggml_nbytes(model.output);
+ vram_weights += ggml_nbytes(model.output_b);
+ }
+ }
+
+ const uint32_t n_ff = hparams.n_ff;
+
+ const int i_gpu_start = n_layer - n_gpu_layers;
+ model.layers.resize(n_layer);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+ const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+ layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
+
+ layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+ layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
+
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+ layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
+
+ layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
+ layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
+
+ layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+ layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
+
+ if (backend == GGML_BACKEND_GPU) {
+ vram_weights +=
+ ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+ ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) +
+ ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) +
+ ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) +
+ ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b);
+ }
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
// if max_alibi_bias > 0 then apply ALiBi
static struct ggml_tensor * llm_build_kqv(
struct ggml_context * ctx,
+ const llama_model & model,
const llama_hparams & hparams,
const llama_kv_cache & kv,
struct ggml_tensor * wo,
int32_t n_tokens,
int32_t n_kv,
float max_alibi_bias,
+ float scale,
const llm_build_cb & cb,
int il) {
const int64_t n_embd = hparams.n_embd;
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
+ if (model.arch == LLM_ARCH_PHI2) {
+ // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+ // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+ }
+
if (max_alibi_bias > 0.0f) {
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
kq = ggml_scale(ctx, kq, kq_scale);
kq = ggml_soft_max(ctx, kq);
cb(kq, "kq_soft_max", il);
} else {
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
cb(kq, "kq_soft_max_ext", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
// apply ALiBi for 13B model
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
// TODO: not tested, could be broken
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
- Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
- struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
- struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = llm_build_kqv(ctx0, hparams, kv_self,
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
ggml_build_forward_expand(gf, cur);
+ return gf;
+ }
+ struct ggml_cgraph * build_phi2() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * attn_norm_output;
+ struct ggml_tensor * ffn_output;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
+
+ // Q_scale
+ struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(Q_scale, "Q_scale", -1);
+
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
+
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ }
+
+ for (int il = 0; il < n_layer; ++il) {
+ attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, cb, il);
+ cb(attn_norm_output, "attn_norm", il);
+
+ // self-attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+ cb(cur, "wqkv", il);
+
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
+
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_custom(
+ ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Qcur = ggml_scale(ctx0, Qcur, Q_scale);
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_custom(
+ ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
+
+ // FF
+ {
+ ffn_output = llm_build_ffn(ctx0, attn_norm_output,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+ NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(ffn_output, "ffn_out", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_output);
+ cb(cur, "l_out", il);
+
+ cur = ggml_add(ctx0, cur, inpL);
+ cb(cur, "l_out", il);
+
+ inpL = cur;
+ }
+
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output_no_bias", -1);
+
+ cur = ggml_add(ctx0, cur, model.output_b);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
return gf;
}
};
OFFLOAD_FUNC_FRC, // force offload
OFFLOAD_FUNC_KQV,
OFFLOAD_FUNC_NR,
- OFFLOAD_FUNC_EMB,
+ OFFLOAD_FUNC_EMB, // embeddings
OFFLOAD_FUNC_OUT,
};
{ "pos_embd", OFFLOAD_FUNC_NR },
{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
+ { "Q_scale", OFFLOAD_FUNC_FRC },
{ "KQ_scale", OFFLOAD_FUNC_FRC },
{ "KQ_mask", OFFLOAD_FUNC_FRC },
{ "K_shift", OFFLOAD_FUNC_FRC },
{ "l_out", OFFLOAD_FUNC },
{ "result_norm", OFFLOAD_FUNC_EMB },
+ { "result_output_no_bias", OFFLOAD_FUNC_EMB },
{ "result_output", OFFLOAD_FUNC_OUT },
};
bool alloc_inp_tokens = false;
bool alloc_inp_embd = false;
bool alloc_inp_pos = false;
+ bool alloc_inp_Q_scale = false;
bool alloc_inp_KQ_scale = false;
bool alloc_inp_KQ_mask = false;
bool alloc_inp_K_shift = false;
alloc_inp_pos = true;
}
- if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
+ if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
}
+ alloc_inp_Q_scale = true;
+ }
+
+ if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
+ ggml_allocr_alloc(lctx.alloc, cur);
+
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ const int64_t n_embd_head = model.hparams.n_embd_head();
+ if (model.arch == LLM_ARCH_PHI2) {
+ // with phi2, we scale the Q to avoid precision issues
+ // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
+ ggml_set_f32(cur, 1.0f);
+ } else {
+ ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
+ }
+ }
+
alloc_inp_KQ_scale = true;
}
{
result = llm.build_qwen();
} break;
+ case LLM_ARCH_PHI2:
+ {
+ result = llm.build_phi2();
+ } break;
default:
GGML_ASSERT(false);
}
ggml_allocr_alloc_graph(lctx.alloc, gf);
- struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
- struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
- GGML_ASSERT(strcmp(res->name, "result_output") == 0);
- GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+ // the output is always the last tensor in the graph
+ struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+ GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+ // the embeddings could be the second to last tensor, or the third to last tensor
+ struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+ if (strcmp(embeddings->name, "result_norm") != 0) {
+ embeddings = gf->nodes[gf->n_nodes - 3];
+ GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+ }
#ifdef GGML_USE_CUBLAS
for (int i = 0; i < gf->n_leafs; i++) {