MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
+
+ifdef LLAMA_DEBUG
+ NVCCFLAGS += -lineinfo
+endif
+
ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC)
else
self.gguf_writer.add_embedding_length(n_embd)
if (n_ff := self.hparams.get("intermediate_size")) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
- if (n_head := self.hparams.get("num_attention_head")) is not None:
+ if (n_head := self.hparams.get("num_attention_heads")) is not None:
self.gguf_writer.add_head_count(n_head)
+ if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+ self.gguf_writer.add_head_count_kv(n_head_kv)
+
+ if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+ self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
+ if (n_experts := self.hparams.get("num_local_experts")) is not None:
+ self.gguf_writer.add_expert_count(n_experts)
+ if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+ self.gguf_writer.add_expert_used_count(n_experts_used)
+
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
def write_tensors(self):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
+ if model_architecture == "MixtralForCausalLM":
+ return MixtralModel
return Model
def _is_model_safetensors(self) -> bool:
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
+ if arch == "MixtralForCausalLM":
+ return gguf.MODEL_ARCH.LLAMA
raise NotImplementedError(f'Architecture "{arch}" not supported!')
self.gguf_writer.add_layer_norm_eps(1e-5)
class MixtralModel(Model):
    """Model subclass selected for the "MixtralForCausalLM" architecture.

    Only vocabulary setup is customized here; everything else comes from
    the base class.
    """

    def set_vocab(self):
        """Set up the vocabulary by delegating to `_set_vocab_sentencepiece`."""
        self._set_vocab_sentencepiece()
+
+
class QwenModel(Model):
@staticmethod
def token_bytes_to_string(b):
ARCH = gguf.MODEL_ARCH.LLAMA
DEFAULT_CONCURRENCY = 8
+
#
# data types
#
pass
# Unquantized element types. `dtype` is the NumPy type used to hold the raw
# values; `valid_conversions` names the types each one is allowed to become
# (presumably checked by the conversion code elsewhere in this file).
DT_F16  = UnquantizedDataType('F16',  dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
DT_F32  = UnquantizedDataType('F32',  dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
# I32 elements are 4 bytes wide, so the backing dtype must be int32; the
# previous int16 halved the per-element size and would misread the data.
DT_I32  = UnquantizedDataType('I32',  dtype = np.dtype(np.int32),   valid_conversions = [])
# NumPy has no native bfloat16, so the raw 16-bit patterns are held as uint16.
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16),  valid_conversions = ['F32', 'F16', 'Q8_0'])
@dataclass(frozen=True)
@dataclass
class Params:
- n_vocab: int
- n_embd: int
- n_layer: int
- n_ctx: int
- n_ff: int
- n_head: int
- n_head_kv: int
- f_norm_eps: float
+ n_vocab: int
+ n_embd: int
+ n_layer: int
+ n_ctx: int
+ n_ff: int
+ n_head: int
+ n_head_kv: int
+ n_experts: int | None = None
+ n_experts_used: int | None = None
+ f_norm_eps: float | None = None
rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None
raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+ n_experts = None
+ n_experts_used = None
+
+ if "num_local_experts" in config:
+ n_experts = config["num_local_experts"]
+ n_experts_used = config["num_experts_per_tok"]
+
return Params(
n_vocab = config["vocab_size"],
n_embd = config["hidden_size"],
n_ff = config["intermediate_size"],
n_head = (n_head := config["num_attention_heads"]),
n_head_kv = config.get("num_key_value_heads", n_head),
+ n_experts = n_experts,
+ n_experts_used = n_experts_used,
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
rope_scaling_type = rope_scaling_type,
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))
+ n_experts = None
+ n_experts_used = None
+ f_rope_freq_base = None
+
# hack to determine LLaMA v1 vs v2 vs CodeLlama
- if config.get("rope_theta") == 1000000:
+ if config.get("moe"):
+ # Mixtral
+ n_ctx = 32768
+ elif config.get("rope_theta") == 1000000:
# CodeLlama
n_ctx = 16384
elif config["norm_eps"] == 1e-05:
# LLaMA v1
n_ctx = 2048
+ if "layers.0.feed_forward.w1.weight" in model:
+ n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
+ if config.get("moe"):
+ n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
+ n_experts = config["moe"]["num_experts"]
+ n_experts_used = config["moe"]["num_experts_per_tok"]
+ f_rope_freq_base = 1e6
+
return Params(
n_vocab = model["tok_embeddings.weight"].shape[0],
n_embd = config["dim"],
n_layer = config["n_layers"],
n_ctx = n_ctx,
- n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
+ n_ff = n_ff,
n_head = (n_head := config["n_heads"]),
n_head_kv = config.get("n_kv_heads", n_head),
+ n_experts = n_experts,
+ n_experts_used = n_experts_used,
f_norm_eps = config["norm_eps"],
- f_rope_freq_base = config.get("rope_theta"),
+ f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
)
@staticmethod
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
self.gguf.add_head_count (params.n_head)
self.gguf.add_head_count_kv (params.n_head_kv)
- self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
+
+ if params.n_experts:
+ self.gguf.add_expert_count(params.n_experts)
+
+ if params.n_experts_used:
+ self.gguf.add_expert_used_count(params.n_experts_used)
+
+ if params.f_norm_eps:
+ self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
+ else:
+ raise ValueError('f_norm_eps is None')
if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
- wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
+ wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
return GGMLFileType.AllF32
#include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
#include <cstddef>
#include <cstdint>
-#include <cinttypes>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
}
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
- const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
- if (col >= ncols) {
+static __global__ void k_get_rows(
+ const void * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
return;
}
- const int r = y[row];
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
- // copy x[r*ncols + col] to dst[row*ncols + col]
- const int xi = r*ncols + col;
- const int di = row*ncols + col;
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
- const int ib = xi/qk; // block index
- const int iqs = (xi%qk)/qr; // quant index
- const int iybs = di - di%qk; // y block start index
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
- dequantize_kernel(x, ib, iqs, v);
+ dequantize_kernel(src0_row, ib, iqs, v);
- dst[iybs + iqs + 0] = v.x;
- dst[iybs + iqs + y_offset] = v.y;
+ dst_row[iybs + iqs + 0] = v.x;
+ dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = src0_row[i00];
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
}
template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ GGML_ASSERT(ne00 % 2 == 0);
+
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
- const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
- const dim3 block_nums(block_num_x, nrows, 1);
- k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
}
template<float (*bin_op)(const float, const float)>
GGML_TENSOR_BINARY_OP_LOCALS
-
int nr0 = ne10/ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
- //size_t nb0 = cnb0[0];
+ size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
- //size_t nb10 = cnb1[0];
+ size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
- //size_t s0 = nb0 / sizeof(src1_t);
+ size_t s0 = nb0 / sizeof(src1_t);
size_t s1 = nb1 / sizeof(src1_t);
size_t s2 = nb2 / sizeof(src1_t);
size_t s3 = nb3 / sizeof(src1_t);
- //size_t s10 = nb10 / sizeof(src1_t);
+ size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s10 == 1);
const int block_size = 128;
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(dst));
- const int ncols = src0->ne[0];
- const int nrows = ggml_nelements(src1);
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) {
case GGML_TYPE_F16:
- get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_F32:
- get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
default:
// TODO: k-quants
}
#endif
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-// const bool use_tensor_cores = true;
-//#else
-// const bool use_tensor_cores = false;
-//#endif
-
ggml_cuda_mul_mat_id_cublas(dst);
-
// TODO: mmq/mmv support
-#else
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const int id = dst->op_params[0];
+#endif
- int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
- int32_t a_id;
- CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ const struct ggml_tensor * ids = src0;
+ const int32_t id = ((int32_t *) dst->op_params)[0];
+ const int32_t n_as = ((int32_t *) dst->op_params)[1];
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
+ std::vector<char> ids_host(ggml_nbytes(ids));
+
+ if (ids->backend == GGML_BACKEND_GPU) {
+ const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ } else {
+ memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+ }
+
+ const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+ const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+ ggml_tensor_extra_gpu src1_row_extra;
+ ggml_tensor_extra_gpu dst_row_extra;
+
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ src1_row.ne[1] = 1;
+ dst_row.ne[1] = 1;
+
+ src1_row.nb[2] = src1_row.nb[1];
+ dst_row.nb[2] = dst_row.nb[1];
+
+ src1_row.nb[3] = src1_row.nb[1];
+ dst_row.nb[3] = dst_row.nb[1];
+
+ src1_row.extra = &src1_row_extra;
+ dst_row.extra = &dst_row_extra;
- ggml_cuda_mul_mat(src0, src1, dst);
-#endif
- (void) _src0;
- (void) _src1;
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+ //int32_t row_id;
+ //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+ const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
+
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+ src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+ src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+ dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+ dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+ ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+ }
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
}
return true;
} break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ return false;
+ } break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_REPEAT:
- case GGML_OP_GET_ROWS:
case GGML_OP_DUP:
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
- case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
UNUSED(params);
}
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
int device_count = ggml_cuda_get_device_count();
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
for (int i = 0; i < device_count; i++) {
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows);
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
} else {
char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
vsnprintf(buffer2, len+1, format, args);
buffer2[len] = 0;
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows);
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows);
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_ARGSORT:
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_GET_ROWS:
{
return op->ne[0] % 4 == 0;
}
case GGML_OP_MUL:
case GGML_OP_DIV:
{
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
bool bcast_row = false;
int64_t nb = ne00;
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
// src1 is a row
GGML_ASSERT(ne11 == 1);
nb = ne00 / 4;
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false);
}
bcast_row = true;
} else {
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false);
}
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
const float scale = ((float *) dst->op_params)[0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- int64_t ny = (ne11 + nrows - 1)/nrows;
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
GGML_ASSERT(src0t == GGML_TYPE_I32);
- const int n_as = ne00;
+ const int n_as = ((int32_t *) dst->op_params)[1];
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = 0;
+ int ne11_mm_min = 1;
const int idx = ((int32_t *) dst->op_params)[0];
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne11 > ne11_mm_min) {
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
}
} break;
case GGML_OP_GET_ROWS:
default: GGML_ASSERT(false && "not implemented");
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
float lmax = -INFINITY;
pdst[i00] = exp_psrc0;
}
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
+
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float4 lmax4 = -INFINITY;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
+
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
#define NB_Q8_0 8
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
}
}
+// Kernel entry point exported to the host as "kernel_mul_mv_q8_0_f32".
+// It does no work of its own and forwards every argument unchanged to
+// kernel_mul_mv_q8_0_f32_impl.
+//
+// The parameter list is sparse: it omits several of the slots the host binds,
+// so each argument after ne00 is pinned to an explicit buffer index. These are
+// the same indices used by the other quantized mat-vec entry points
+// (ne01=4, ne02=5, ne10=9, ne12=11, ne0=15, ne1=16, r2=17, r3=18).
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
#define N_F32_F32 4
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
}
}
+// Kernel entry point exported to the host as "kernel_mul_mv_f32_f32".
+// It does no work of its own: every argument is forwarded unchanged, in
+// declaration order, to kernel_mul_mv_f32_f32_impl.
+// Only r2 and r3 carry an explicit [[buffer(n)]] attribute (17 and 18); the
+// arguments before them are declared without one.
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
}
}
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
+}
+// Kernel entry point exported to the host as "kernel_mul_mv_f16_f32_1row".
+// It does no work of its own: every argument is forwarded unchanged, in
+// declaration order, to kernel_mul_mv_f16_f32_1row_impl.
+// Only r2 and r3 carry an explicit [[buffer(n)]] attribute (17 and 18).
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#define N_F16_F32 4
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
}
}
+// Kernel entry point exported to the host as "kernel_mul_mv_f16_f32".
+// It does no work of its own: every argument is forwarded unchanged, in
+// declaration order, to kernel_mul_mv_f16_f32_impl.
+// Only r2 and r3 carry an explicit [[buffer(n)]] attribute (17 and 18).
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_f16_f32_l4(
device const char * src0,
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
+ device const half * src0,
+ device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
}
}
+// Copies half-precision source elements into a float destination.
+//
+// Each threadgroup handles one source row, selected by tgpig as (i01, i02, i03).
+// ne00..ne03 / nb00..nb03 are the source extents and byte strides;
+// ne0..ne3 / nb0..nb3 are the destination extents and byte strides.
+// The threads of the group walk the row's ne00 elements in steps of ntg.x.
+kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ // source row coordinates taken from the threadgroup position
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ // linear element index of the first element of this source row
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ // split that linear index into destination coordinates using the dst extents
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ // destination address of the row start, computed from the dst byte strides
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ // the source element is addressed via byte strides; the destination is
+ // indexed as consecutive floats starting at dst_data
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
//====================================== dot products =========================
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
}
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant uint & r2 [[buffer(17)]],
constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
}
}
#else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
#endif
+// Kernel entry point exported to the host as "kernel_mul_mv_q3_K_f32".
+// It does no work of its own: every argument is forwarded unchanged to
+// kernel_mul_mv_q3_K_f32_impl.
+// Every argument after ne00 is pinned to an explicit buffer index
+// (ne01=4, ne02=5, ne10=9, ne12=11, ne0=15, ne1=16, r2=17, r3=18).
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
#if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01 [[buffer(4)]],
- constant int64_t & ne02 [[buffer(5)]],
- constant int64_t & ne10 [[buffer(9)]],
- constant int64_t & ne12 [[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
}
#else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int ix = tiisg/4; // 0...7
}
#endif
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
-
}
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant uint & r2 [[buffer(17)]],
constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
}
}
+// Kernel entry point exported to the host as "kernel_mul_mv_q6_K_f32".
+// It does no work of its own: every argument is forwarded unchanged to
+// kernel_mul_mv_q6_K_f32_impl.
+// Every argument after ne00 is pinned to an explicit buffer index
+// (ne01=4, ne02=5, ne10=9, ne12=11, ne0=15, ne1=16, r2=17, r3=18).
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
- device const int * src1,
+ device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint tgpig[[threadgroup_position_in_grid]],
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
- uint tptg[[threads_per_threadgroup]]) {
- const int i = tgpig;
- const int r = ((device int32_t *) src1)[i];
+ uint3 tptg [[threads_per_threadgroup]]) {
+ //const int64_t i = tgpig;
+ //const int64_t r = ((device int32_t *) src1)[i];
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp;
dequantize_func(
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
+// Gathers float rows from src0 into dst, using int32 row indices read from src1.
+//
+// Each threadgroup handles one index, located at (tgpig.x, tgpig.y) in src1.
+// nb01/nb02 are byte strides into src0, nb10/nb11 into src1, nb1/nb2 into dst.
+// ne00 is the number of elements copied per row; ne10 is not read in the body.
+kernel void kernel_get_rows_f32(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ // row index to fetch, read as int32 at byte offset i11*nb11 + i10*nb10 of src1
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ // the second index coordinate also selects the src0 slice along nb02
+ const int64_t i02 = i11;
+
+ // the threads of the group walk the row in steps of tptg.x
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+// Gathers half-precision rows from src0 into the float dst, using int32 row
+// indices read from src1.
+//
+// Each threadgroup handles one index, located at (tgpig.x, tgpig.y) in src1.
+// nb01/nb02 are byte strides into src0, nb10/nb11 into src1, nb1/nb2 into dst.
+// ne00 is the number of elements copied per row; ne10 is not read in the body.
+kernel void kernel_get_rows_f16(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ // row index to fetch, read as int32 at byte offset i11*nb11 + i10*nb10 of src1
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ // the second index coordinate also selects the src0 slice along nb02
+ const int64_t i02 = i11;
+
+ // the threads of the group walk the row in steps of tptg.x; each element is
+ // read as half and stored as float
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
- device const int32_t * ids,
+ device const uchar * ids,
device const uchar * src1,
- device float * dst,
+ device uchar * dst,
+ constant int64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
+ constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant int64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
- src0[ids[idx]],
- src1,
- dst,
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
ne00,
ne02,
nb01,
#define QK_NL 4
#endif
+//
+// get rows
+//
+
typedef void (get_rows_t)(
device const void * src0,
- device const int * src1,
+ device const char * src1, // indices are now passed as raw bytes instead of int *
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint, uint, uint);
+ constant uint64_t & nb2,
+ uint3, uint, uint3); // presumably tgpig, tiitg, tptg as in kernel_get_rows_f16 — confirm against kernel_get_rows
-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+//
+// matrix-matrix multiplication
+//
+
typedef void (mat_mm_t)(
device const uchar * src0,
device const uchar * src1,
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+//
+// indirect matrix-matrix multiplication
+//
+
typedef void (mat_mm_id_t)(
- device const int32_t * ids,
+ device const uchar * ids,
device const uchar * src1,
- device float * dst,
+ device uchar * dst,
+ constant int64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
+ constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant int64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// indirect matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_f32_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // sgitg is not referenced in this kernel body
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_f32_f32_impl(
+ src0[id],
+ src1 + bid*nb11, // src1 advanced by bid*nb11 bytes
+ (device float *) (dst + bid*nb1), // dst advanced by bid*nb1 bytes
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_f16_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) { // sgitg is not referenced in this kernel body
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_f16_f32_impl(
+ src0[id],
+ src1 + bid*nb11, // src1 advanced by bid*nb11 bytes
+ (device float *) (dst + bid*nb1), // dst advanced by bid*nb1 bytes
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_q8_0_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_q8_0_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32( // picks one src0 buffer via ids, then forwards to mul_vec_q_n_f32_impl<block_q4_0, ...>
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32( // picks one src0 buffer via ids, then forwards to mul_vec_q_n_f32_impl<block_q4_1, ...>
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32( // picks one src0 buffer via ids, then forwards to mul_vec_q_n_f32_impl<block_q5_0, ...>
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32( // picks one src0 buffer via ids, then forwards to mul_vec_q_n_f32_impl<block_q5_1, ...>
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_q2_K_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_q2_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_q3_K_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_q3_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_q4_K_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_q4_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_q5_K_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_q5_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32( // picks one src0 buffer via ids, then forwards to kernel_mul_mv_q6_K_f32_impl
+ device const char * ids, // int32 ids, addressed as bytes (row stride nbi1)
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1, // byte stride applied per bid when reading ids
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx, // element of the ids row to read
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]], // not referenced in this kernel body
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate buffers, indexable by id
+
+ const int64_t bid = tgpig.z/(ne12*ne13); // quotient of the z grid position selects the ids row and the src1/dst offset
+
+ tgpig.z = tgpig.z%(ne12*ne13); // only the remainder is handed to the impl as its z position
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // NOTE(review): used unchecked as an index into src0[8] — confirm host guarantees 0 <= id < 8
+
+ kernel_mul_mv_q6_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes, then viewed as float
+ (device float *) ( dst + bid*nb1), // dst advanced by bid*nb1 bytes, then viewed as float
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
- struct ggml_tensor * as[],
+ struct ggml_tensor * const as[],
+ int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b) {
- int64_t n_as = ids->ne[0];
-
GGML_ASSERT(ids->type == GGML_TYPE_I32);
- GGML_ASSERT(ggml_is_vector(ids));
+ GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+ GGML_ASSERT(ids->ne[1] == b->ne[1]);
+ GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
- GGML_ASSERT(id >= 0 && id < n_as);
+ GGML_ASSERT(id >= 0 && id < ids->ne[0]);
bool is_node = false;
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
ggml_set_op_params_i32(result, 0, id);
+ ggml_set_op_params_i32(result, 1, n_as);
result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = ids;
result->src[1] = b;
- for (int64_t i = 0; i < n_as; i++) {
+ for (int i = 0; i < n_as; i++) {
struct ggml_tensor * a = as[i];
GGML_ASSERT(ggml_are_same_shape(as[0], a));
GGML_ASSERT(ggml_can_mul_mat(a, b));
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
- GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
+ GGML_ASSERT(b->ne[3] == 1);
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
bool is_node = false;
// TODO: implement non F32 return
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
- struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
+ // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+ // all the experts for each batch element and the processing would become incredibly slow
// TODO: find the optimal values for these
- if (ggml_is_contiguous(src0) &&
+ if (dst->op != GGML_OP_MUL_MAT_ID &&
+ ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
//src0->type == GGML_TYPE_F32 &&
src1->type == GGML_TYPE_F32 &&
}
#endif
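+// off1 = index of the first row to process (row offset applied to both src1 via nb11 and dst via nb1)
+// cne1 = number of rows to process starting at off1 (used in place of ne11 and ne1)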
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- struct ggml_tensor * dst) {
+ struct ggml_tensor * dst,
+ int64_t off1, int64_t cne1) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+ const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+ float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
if (type != GGML_TYPE_F32) {
float * const wdata = params->wdata;
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- ne11, ne01, ne10,
- 1.0f, y, ne10,
- x, ne00,
- 0.0f, d, ne01);
+ cne1, ne01, ne10,
+ 1.0f, y, ne10,
+ x, ne00,
+ 0.0f, d, ne01);
}
}
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
assert(params->wsize >= ne11*ne12*ne13*row_size);
+ assert(src1->type == GGML_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
const int64_t nr0 = ne01; // src0 rows
- const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+ const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
- const int64_t i13 = (ir1/(ne12*ne11));
- const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
- const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+ const int64_t i13 = (ir1/(ne12*cne1));
+ const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+ const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
// broadcast src0 into src1
const int64_t i03 = i13/r3;
static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0, // the ids tensor (aliased to `ids` below)
+ const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- const int id = ggml_get_op_params_i32(dst, 0);
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { // these phases are delegated once, not per row
+ // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+ ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]); // dst->src[2] is the matrix for row_id 0; full row range [0, dst->ne[1])
+ return;
+ }
- const int a_id = ((int32_t *)ids->data)[id];
+ const struct ggml_tensor * ids = src0;
+ const int id = ggml_get_op_params_i32(dst, 0); // op param 0: element index within each row of ids
+ const int n_as = ggml_get_op_params_i32(dst, 1); // op param 1: upper bound for valid row_id values
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { // one mul_mat call per row of ids
+ const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); // ids[i01][id], read via byte strides
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
- ggml_compute_forward_mul_mat(params, src0, src1, dst);
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2]; // candidate matrices start at dst->src[2]
+ ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1); // off1 = i01, cne1 = 1: compute only row i01
+ }
}
// ggml_compute_forward_out_prod
return;
}
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+
const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == ggml_type_size(type));
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == ggml_type_size(type));
+ assert(ggml_nrows(dst) == nr);
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- dequantize_row_q(
- (const void *) ((char *) src0->data + r*src0->nb[1]),
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
+ dequantize_row_q(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
+ }
}
}
return;
}
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
+ GGML_TENSOR_BINARY_OP_LOCALS
- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == sizeof(ggml_fp16_t));
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(ggml_fp16_t));
+ assert(ggml_nrows(dst) == nr);
- for (int j = 0; j < nc; ++j) {
- ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ ggml_fp16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
}
}
}
return;
}
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
+ GGML_TENSOR_BINARY_OP_LOCALS
- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == sizeof(float));
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(float));
+ assert(ggml_nrows(dst) == nr);
- ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i*dst->nb[1]),
- (float *) ((char *) src0->data + r*src0->nb[1]));
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+ }
+ }
}
}
} break;
case GGML_OP_MUL_MAT:
{
- ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
} break;
case GGML_OP_MUL_MAT_ID:
{
- ggml_compute_forward_mul_mat_id(params, tensor);
+ ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_OUT_PROD:
{
#define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
-#define GGML_MAX_SRC 6
+#define GGML_MAX_SRC 10
#define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
- struct ggml_tensor * as[],
+ struct ggml_tensor * const as[],
+ int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
struct ggml_context * ctx,
struct ggml_tensor * a);
+ // supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+ EXPERT_COUNT = "{arch}.expert_count"
+ EXPERT_USED_COUNT = "{arch}.expert_used_count"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_ROT_EMBD = auto()
+ FFN_GATE_INP = auto()
+ FFN_NORM = auto()
FFN_GATE = auto()
FFN_DOWN = auto()
FFN_UP = auto()
- FFN_NORM = auto()
+ FFN_GATE_EXP = auto()
+ FFN_DOWN_EXP = auto()
+ FFN_UP_EXP = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
+ MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+ MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
+ MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
+ MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.GPTNEOX: [
MODEL_TENSOR.TOKEN_EMBD,
def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
+ def add_expert_count(self, count: int) -> None:  # record the model's expert count in the GGUF metadata
+ self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)  # written as uint32 under the arch-specific "{arch}.expert_count" key
+
+ def add_expert_used_count(self, count: int) -> None:  # record the used-expert count in the GGUF metadata
+ self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)  # written as uint32 under the arch-specific "{arch}.expert_used_count" key
+
def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
"model.layers.{bid}.ln2", # yi
),
+ MODEL_TENSOR.FFN_GATE_INP: (
+ "layers.{bid}.feed_forward.gate", # mixtral
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+ ),
+
# Feed-forward up
MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
"transformer.h.{bid}.mlp.w1", # qwen
),
+ MODEL_TENSOR.FFN_UP_EXP: (
+ "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
+ "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
+ ),
+
# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
- "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
- "layers.{bid}.feed_forward.w1", # llama-pth
- "transformer.h.{bid}.mlp.w2", # qwen
+ "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
+ "layers.{bid}.feed_forward.w1", # llama-pth
+ "transformer.h.{bid}.mlp.w2", # qwen
+ ),
+
+ MODEL_TENSOR.FFN_GATE_EXP: (
+ "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
+ "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
),
# Feed-forward down
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
),
+ MODEL_TENSOR.FFN_DOWN_EXP: (
+ "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
+ "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
+ ),
+
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
),
for tensor, keys in self.block_mappings_cfg.items():
if tensor not in MODEL_TENSORS[arch]:
continue
- tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
- self.mapping[tensor_name] = (tensor, tensor_name)
- for key in keys:
- key = key.format(bid = bid)
- self.mapping[key] = (tensor, tensor_name)
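+ # TODO: make this configurable - the expert count is hard-coded; names without an {xid} placeholder format identically on every pass, so the extra passes only rewrite the same mapping entries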
+ n_experts = 8
+ for xid in range(n_experts):
+ tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
+ self.mapping[tensor_name] = (tensor, tensor_name)
+ for key in keys:
+ key = key.format(bid = bid, xid = xid)
+ self.mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
[tool.poetry]
name = "gguf"
-version = "0.6.0"
+version = "0.7.0"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
#define LLAMA_ATTRIBUTE_FORMAT(...)
#endif
-#define LLAMA_MAX_NODES 8192
+#define LLAMA_MAX_NODES 8192
+#define LLAMA_MAX_EXPERTS 8
//
// logging
LLM_KV_FEED_FORWARD_LENGTH,
LLM_KV_USE_PARALLEL_RESIDUAL,
LLM_KV_TENSOR_DATA_LAYOUT,
+ LLM_KV_EXPERT_COUNT,
+ LLM_KV_EXPERT_USED_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
+ { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
+ { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_ROT_EMBD,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
- LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_DOWN_EXP,
+ LLM_TENSOR_FFN_GATE_EXP,
+ LLM_TENSOR_FFN_UP_EXP,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
};
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
},
},
{
std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
}
+
+ std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { // overload taking two indices: block (bid) and expert (xid)
+ return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix; // fill both indices into the name pattern, then append "." + suffix
+ }
};
//
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_ff;
+ uint32_t n_expert = 0;
+ uint32_t n_expert_used = 0;
float f_norm_eps;
float f_norm_rms_eps;
float f_max_alibi_bias;
bool operator!=(const llama_hparams & other) const {
- if (this->vocab_only != other.vocab_only) return true;
- if (this->n_vocab != other.n_vocab) return true;
- if (this->n_ctx_train != other.n_ctx_train) return true;
- if (this->n_embd != other.n_embd) return true;
- if (this->n_head != other.n_head) return true;
- if (this->n_head_kv != other.n_head_kv) return true;
- if (this->n_layer != other.n_layer) return true;
- if (this->n_rot != other.n_rot) return true;
- if (this->n_ff != other.n_ff) return true;
+ if (this->vocab_only != other.vocab_only) return true;
+ if (this->n_vocab != other.n_vocab) return true;
+ if (this->n_ctx_train != other.n_ctx_train) return true;
+ if (this->n_embd != other.n_embd) return true;
+ if (this->n_head != other.n_head) return true;
+ if (this->n_head_kv != other.n_head_kv) return true;
+ if (this->n_layer != other.n_layer) return true;
+ if (this->n_rot != other.n_rot) return true;
+ if (this->n_ff != other.n_ff) return true;
+ if (this->n_expert != other.n_expert) return true;
+ if (this->n_expert_used != other.n_expert_used) return true;
+
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
struct ggml_tensor * ffn_down; // w2
struct ggml_tensor * ffn_up; // w3
+ // ff MoE
+ struct ggml_tensor * ffn_gate_inp;
+ struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
+ struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
+ struct ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS];
+
// ff bias
struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3
ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
+ ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
+ ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+
+ GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
+ GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
+ if (hparams.n_expert > 0) {
+ GGML_ASSERT(hparams.n_expert_used > 0);
+ } else {
+ GGML_ASSERT(hparams.n_expert_used == 0);
+ }
// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
+ LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
+ LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
- layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
- layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
- layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+ layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
+
+ if (layer.ffn_gate_inp == nullptr) {
+ GGML_ASSERT(hparams.n_expert == 0);
+ GGML_ASSERT(hparams.n_expert_used == 0);
+
+ layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
+ layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
+ layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+ } else {
+ GGML_ASSERT(hparams.n_expert > 0);
+ GGML_ASSERT(hparams.n_expert_used > 0);
+
+ // MoE branch
+ for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+ layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
+ layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split);
+ layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
+ }
+ }
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
(layer.bk ? ggml_nbytes(layer.bk) : 0) +
(layer.bv ? ggml_nbytes(layer.bv) : 0) +
(layer.bo ? ggml_nbytes(layer.bo) : 0) +
- ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
- ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+ ggml_nbytes(layer.ffn_norm);
+
+ if (layer.ffn_gate_inp == nullptr) {
+ vram_weights +=
+ ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+ } else {
+ vram_weights += ggml_nbytes(layer.ffn_gate_inp);
+ for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+ vram_weights +=
+ ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
+ }
+ }
}
}
} break;
const int64_t n_head_kv;
const int64_t n_embd_head;
const int64_t n_embd_gqa;
+ const int64_t n_expert;
+ const int64_t n_expert_used;
const float freq_base;
const float freq_scale;
n_head_kv (hparams.n_head_kv),
n_embd_head (hparams.n_embd_head()),
n_embd_gqa (hparams.n_embd_gqa()),
+ n_expert (hparams.n_expert),
+ n_expert_used (hparams.n_expert_used),
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
- {
+ if (model.layers[il].ffn_gate_inp == nullptr) {
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
model.layers[il].ffn_down, NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
+ } else {
+ // MoE branch
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+ cb(logits, "ffn_moe_logits", il);
+
+ ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+ cb(probs, "ffn_moe_probs", il);
+
+ // select experts
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+ cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+ ggml_tensor * weights = ggml_get_rows(ctx0,
+ ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+ cb(weights, "ffn_moe_weights", il);
+
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+ cb(weights_sum, "ffn_moe_weights_sum", il);
+
+ weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+ cb(weights, "ffn_moe_weights_norm", il);
+
+ // compute expert outputs
+ ggml_tensor * moe_out = nullptr;
+
+ for (int i = 0; i < n_expert_used; ++i) {
+ ggml_tensor * cur_expert;
+
+ ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
+ cb(cur_up, "ffn_moe_up", il);
+
+ ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
+ cb(cur_gate, "ffn_moe_gate", il);
+
+ cur_gate = ggml_silu(ctx0, cur_gate);
+ cb(cur_gate, "ffn_moe_silu", il);
+
+ cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+ cb(cur_expert, "ffn_moe_gate_par", il);
+
+ cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+ cb(cur_expert, "ffn_moe_down", il);
+
+ cur_expert = ggml_mul(ctx0, cur_expert,
+ ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+ cb(cur_expert, "ffn_moe_weighted", il);
+
+ if (i == 0) {
+ moe_out = cur_expert;
+ } else {
+ moe_out = ggml_add(ctx0, moe_out, cur_expert);
+ cb(moe_out, "ffn_moe_out", il);
+ }
+ }
+
+ cur = moe_out;
}
cur = ggml_add(ctx0, cur, ffn_inp);
{ "ffn_relu", OFFLOAD_FUNC },
{ "ffn_sqr(relu)", OFFLOAD_FUNC },
+ { "ffn_moe_logits", OFFLOAD_FUNC },
+ { "ffn_moe_probs", OFFLOAD_FUNC },
+ { "ffn_moe_argsort", OFFLOAD_FUNC },
+ { "ffn_moe_weights", OFFLOAD_FUNC },
+ { "ffn_moe_weights_sum", OFFLOAD_FUNC },
+ { "ffn_moe_weights_norm", OFFLOAD_FUNC },
+ { "ffn_moe_weighted", OFFLOAD_FUNC },
+ { "ffn_moe_up", OFFLOAD_FUNC },
+ { "ffn_moe_gate", OFFLOAD_FUNC },
+ { "ffn_moe_silu", OFFLOAD_FUNC },
+ { "ffn_moe_gate_par", OFFLOAD_FUNC },
+ { "ffn_moe_down", OFFLOAD_FUNC },
+ { "ffn_moe_out", OFFLOAD_FUNC },
+
{ "l_out", OFFLOAD_FUNC },
{ "result_norm", OFFLOAD_FUNC_EMB },
workers.clear();
}
-static ggml_type get_k_quant_type(
- quantize_state_internal & qs,
- ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
-) {
+static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
const std::string name = ggml_get_name(tensor);
+
// TODO: avoid hardcoded tensor names - use the TN_* constants
const llm_arch arch = qs.model.arch;
const auto tn = LLM_TN(arch);
// nearly negligible increase in model size by quantizing this tensor with more bits:
if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
}
+ if (qs.model.hparams.n_expert == 8) {
+ // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+ // TODO: explore better strategies
+ new_type = GGML_TYPE_Q8_0;
+ }
++qs.i_attention_wv;
+ } else if (name.find("attn_k.weight") != std::string::npos) {
+ if (qs.model.hparams.n_expert == 8) {
+ // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+ // TODO: explore better strategies
+ new_type = GGML_TYPE_Q8_0;
+ }
} else if (name.find("ffn_down.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
quantize &= params->quantize_output_tensor || name != "output.weight";
quantize &= !params->only_copy;
+ // do not quantize expert gating tensors
+ quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
+
enum ggml_type new_type;
void * new_data;
size_t new_size;
size_t size = ggml_nelements(tensor);
std::vector<float> data(size);
- std::random_device rd;
-
#if 0
std::default_random_engine generator(rd());
std::uniform_real_distribution<float> distribution(min, max);
}
#endif
auto init_thread = [&](size_t start, size_t end) {
+ std::random_device rd;
std::default_random_engine generator(rd());
std::uniform_real_distribution<float> distribution(min, max);
t.join();
}
- if (tensor->type == GGML_TYPE_F32) {
+ if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
std::vector<uint8_t> buf(ggml_nbytes(t));
ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
+ ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+ size_t bs = ggml_blck_size(t->type);
+
// access elements by index to avoid gaps in views
for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
- for (int64_t i0 = 0; i0 < t->ne[0]; i0++) {
- size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
- float v;
+ for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
+ size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
if (t->type == GGML_TYPE_F16) {
- v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
+ tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
} else if (t->type == GGML_TYPE_F32) {
- v = *(float *) &buf[i];
+ tv.push_back(*(float *) &buf[i]);
} else if (t->type == GGML_TYPE_I32) {
- v = *(int32_t *) &buf[i];
+ tv.push_back((float)*(int32_t *) &buf[i]);
+ } else if (ggml_is_quantized(t->type)) {
+ std::vector<float> vq(ggml_blck_size(t->type));
+ tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
+ tv.insert(tv.end(), vq.begin(), vq.end());
} else {
GGML_ASSERT(false);
}
- tv.push_back(v);
}
}
}
struct test_case {
virtual ~test_case() {}
+ virtual std::string op_desc(ggml_tensor * t) { // label compared against the op_name filter and printed for the test; virtual so a test case can override it
+ return ggml_op_desc(t); // default: the ggml op description of the given tensor
+ }
+
virtual std::string vars() {
return "";
}
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
virtual double max_nmse_err() {
- return 1e-6;
+ return 1e-7;
}
virtual void initialize_tensors(ggml_context * ctx) {
ggml_tensor * out = build_graph(ctx);
- if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
- //printf(" %s: skipping\n", ggml_op_desc(out));
+ if (op_name != nullptr && op_desc(out) != op_name) {
+ //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
}
- printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
+ printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout);
// check if backends support op
for (size_t i = 0; i < f1.size(); i++) {
// check for nans
if (std::isnan(f1[i]) || std::isnan(f2[i])) {
- printf("NaN at index %zu ", i);
+ printf("[%s] NaN at index %zu (%f %f) ", ggml_op_desc(t1), i, f1[i], f2[i]);
ud->ok = false;
return true;
}
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
if (std::signbit(f1[i]) != std::signbit(f2[i])) {
- printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+ printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false;
return true;
}
} else {
- printf("inf mismatch: %f %f ", f1[i], f2[i]);
+ printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false;
return true;
}
double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) {
- printf("NMSE = %f ", err);
+ printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+ //for (int i = 0; i < f1.size(); i++) {
+ // printf("(%f, %f) ", f1[i], f2[i]);
+ //}
+ //printf("\n");
ud->ok = false;
}
return true;
+
+ GGML_UNUSED(index);
};
ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
ggml_tensor * out = build_graph(ctx);
- if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
- //printf(" %s: skipping\n", ggml_op_desc(out));
+ if (op_name != nullptr && op_desc(out) != op_name) {
+ //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
}
- int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
+ int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout);
// check if backends support op
return size;
};
for (int i = 0; i < gf->n_nodes; i++) {
- if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
+ if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
continue;
+ }
mem += tensor_op_size(gf->nodes[i]);
}
const int n; // cols
const int m; // rows
const int r; // rows to get
+ const int b; // batch size
+ const bool v; // view (non-contiguous src1)
std::string vars() override {
- return VARS_TO_STR4(type, n, m, r);
+ return VARS_TO_STR6(type, n, m, r, b, v);
}
- test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
- : type(type), n(n), m(m), r(r) {}
+ test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+ : type(type), n(n), m(m), r(r), b(b), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
- ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
+ ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b); // b batches, each with m rows of n columns
+ ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b); // r row indices for each of the b batches
+ if (v) {
+ rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0); // keep the first r/2 indices per batch but reuse the full tensor's row stride, so the view is non-contiguous
+ }
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
// rows
- std::vector<int> data(r);
- for (int i = 0; i < r; i++) {
+ std::vector<int> data(r*b);
+ for (int i = 0; i < r*b; i++) {
data[i] = rand() % m;
}
- ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
+ ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else {
init_tensor_uniform(t);
}
const int64_t m;
const int64_t n;
const int64_t k;
- const std::array<int64_t, 2> bs; // dims 3 and 4
- const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+ const bool v; // view (non-contiguous ids)
std::string vars() override {
- return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+ return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
}
double max_nmse_err() override {
}
size_t op_size(ggml_tensor * t) override {
- size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+ size_t a = ggml_nbytes(t->src[2]) * n;
size_t b = ggml_nbytes(t->src[1]) * m;
size_t c = ggml_nbytes(t);
return a + b + c;
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 2, int id = 0,
- int64_t m = 32, int64_t n = 32, int64_t k = 32,
- std::array<int64_t, 2> bs = {10, 10},
- std::array<int64_t, 2> nr = {2, 2})
+ int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
: type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
- m(m), n(n), k(k), bs(bs), nr(nr) {}
+ m(m), n(n), k(k), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
std::vector<ggml_tensor *> mats;
for (int i = 0; i < n_mats; i++) {
- ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+ ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m); // each candidate matrix is a plain 2D (k, m) tensor
mats.push_back(a);
}
- ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
- ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
- ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
+ ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n); // one row of n_mats matrix indices for each of the n columns of b
+ if (v) {
+ ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0); // keep the first n_mats/2 ids per row but reuse the full row stride, so the view is non-contiguous
+ }
+ ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
+ ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b); // id is halved for the view so it stays inside the narrowed ids row
return out;
}
void initialize_tensors(ggml_context * ctx) override {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
// ids
- std::vector<int> data(n_mats);
- for (int i = 0; i < n_mats; i++) {
- data[i] = i;
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i % n_mats;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
}
- std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
- ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
};
+// Mixtral MOE: test case that builds a mixture-of-experts feed-forward graph (expert gating, top-k selection, weighted sum of the selected experts' outputs)
+struct test_moe : public test_case {
+ const int n_experts; // total number of experts
+ const int n_experts_per_tok; // number of experts selected for each token
+ const int n_tokens; // number of input columns (tokens)
+ const int n_embd; // width of the input activations
+ const int n_ff; // hidden width of each expert's feed-forward block
+
+ std::string op_desc(ggml_tensor * t) override {
+ return "MOE"; // fixed label, independent of the op of the graph's output tensor
+
+ GGML_UNUSED(t);
+ }
+
+ std::string vars() override {
+ return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
+ }
+
+ test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
+ : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
+ }
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts); // gating weights: one n_embd-wide row per expert
+
+ std::vector<ggml_tensor *> ffn_up_exp(n_experts);
+ std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
+ std::vector<ggml_tensor *> ffn_down_exp(n_experts);
+
+ for (int i = 0; i < n_experts; ++i) { // every expert gets its own up / gate / down weight matrices
+ ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+ ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+ ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+ }
+
+ ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); // input activations, one n_embd-wide column per token
+
+ ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // gating scores: one per expert for each token
+ ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd)); // softmax over the scores; nullptr and 1/sqrt(n_embd) are presumably the mask and scale arguments - TODO confirm against ggml_soft_max_ext
+
+ // select the n_experts_per_tok experts used for each token
+ ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
+
+ ggml_tensor * weights = ggml_get_rows(ctx, // gather the probabilities of the selected experts, per token
+ ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts); // reshaped so each probability is its own 1-element row, with tokens along the third dimension
+
+ weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // back to 2D: n_experts_per_tok weights per token
+
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // per-token sum of the selected weights
+
+ weights = ggml_div(ctx, weights, weights_sum); // normalize so the selected weights of each token sum to 1
+
+ // compute the output of each selected expert and accumulate the weighted results
+ ggml_tensor * moe_out = nullptr;
+
+ for (int i = 0; i < n_experts_per_tok; ++i) { // i is the slot within each token's selection, not an expert id
+ ggml_tensor * cur_expert;
+
+ ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur); // up projection with the expert chosen in slot i
+
+ ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur); // gate projection with the same expert
+
+ cur_gate = ggml_silu(ctx, cur_gate);
+
+ cur_expert = ggml_mul(ctx, cur_up, cur_gate); // element-wise product of the up projection and the SiLU-activated gate
+
+ cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // down projection with the same expert
+
+ cur_expert = ggml_mul(ctx, cur_expert, // scale the expert output by its normalized weight
+ ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); // view of slot i's weight for every token (1 x n_tokens)
+
+ if (i == 0) {
+ moe_out = cur_expert; // the first slot initializes the accumulator
+ } else {
+ moe_out = ggml_add(ctx, moe_out, cur_expert);
+ }
+ }
+
+ cur = moe_out;
+
+ return cur;
+ }
+};
+
enum test_mode {
MODE_TEST,
MODE_PERF,
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases;
+ const ggml_type all_types[] = {
+ GGML_TYPE_F32, GGML_TYPE_F16,
+ GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+ GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+ GGML_TYPE_Q8_0,
+ GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+ GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+ GGML_TYPE_Q6_K
+ };
+
// unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
}
- for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
- test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
+ test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
+ for (ggml_type type : all_types) {
+ for (int b : {1, 7}) {
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
+ }
+ }
}
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
test_cases.emplace_back(new test_dup());
- test_cases.emplace_back(new test_cpy());
+
+ for (ggml_type type : all_types) {
+ test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
+ }
+
test_cases.emplace_back(new test_cont());
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
};
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+ //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+ //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
}
- const ggml_type all_types[] = {
- GGML_TYPE_F32, GGML_TYPE_F16,
- GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
- GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
- GGML_TYPE_Q8_0,
- GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
- GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
- GGML_TYPE_Q6_K
- };
-
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
// FIXME: CPU crashes on f16xf16
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
- for (int n_mats : {1, 2, 4}) {
+ for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) {
- test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
+ }
}
}
}
test_cases.emplace_back(new test_concat());
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+ test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
}
- test_cases.emplace_back(new test_sum_rows());
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10}));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1}));
+
+#if !defined(__SANITIZE_THREAD__)
+ // FIXME: these tests use too much memory with thread sanitizer
+ test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336));
+ //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
+#endif
// run tests
if (mode == MODE_TEST) {
ggml_backend_free(backend_cpu);
return n_ok == test_cases.size();
- } else if (mode == MODE_PERF) {
+ }
+
+ if (mode == MODE_PERF) {
for (auto & test : test_cases) {
test->eval_perf(backend, op_name);
}
return true;
- } else {
- GGML_ASSERT(false);
}
+
+ GGML_ASSERT(false);
+ return false;
}
static void usage(char ** argv) {
}
printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+
if (n_ok != ggml_backend_reg_get_count()) {
printf("\033[1;31mFAIL\033[0m\n");
return 1;
- } else {
- printf("\033[1;32mOK\033[0m\n");
- return 0;
}
+
+ printf("\033[1;32mOK\033[0m\n");
+ return 0;
}