GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
+ GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
GGML_OP_TRANSPOSE,
GGML_OP_GET_ROWS,
GGML_OP_GET_ROWS_BACK,
+ GGML_OP_SET_ROWS,
GGML_OP_DIAG,
GGML_OP_DIAG_MASK_INF,
GGML_OP_DIAG_MASK_ZERO,
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
+ // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+ GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
struct ggml_tensor * b, // row indices
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
+ // a TD [n_embd, ne1, ne2, ne3]
+ // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+ // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
+ //
+ // undefined behavior if destination rows overlap
+ //
+ // broadcast:
+ // ne2 % ne11 == 0
+ // ne3 % ne12 == 0
+ //
+ // return view(a)
+ GGML_API struct ggml_tensor * ggml_set_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // destination
+ struct ggml_tensor * b, // source
+ struct ggml_tensor * c); // row indices
+
GGML_API struct ggml_tensor * ggml_diag(
struct ggml_context * ctx,
struct ggml_tensor * a);
static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = {
+ .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
{
ggml_compute_forward_get_rows_back(params, tensor);
} break;
+ case GGML_OP_SET_ROWS:
+ {
+ ggml_compute_forward_set_rows(params, tensor);
+ } break;
case GGML_OP_DIAG:
{
ggml_compute_forward_diag(params, tensor);
n_tasks = n_threads;
} break;
case GGML_OP_GET_ROWS:
+ case GGML_OP_SET_ROWS:
{
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
// decreases performance with GPU offloading
return ggml_graph_compute(cgraph, &cplan);
}
+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+ memcpy(y, x, n * sizeof(float));
+}
+
void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
int64_t i = 0;
#if defined(__F16C__)
switch (op->op) {
case GGML_OP_CPY:
+ case GGML_OP_SET_ROWS:
return
op->type != GGML_TYPE_IQ3_XXS &&
op->type != GGML_TYPE_IQ3_S &&
if (ggml_is_contiguous(dst)) {
// TODO: simplify
if (nb00 == sizeof(float)) {
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+ if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+ ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+ from_float(src0_ptr, dst_ptr + id, ne00);
id += rs;
}
id += rs * (ne01 - ir1);
{
ggml_compute_forward_repeat_f32(params, dst);
} break;
+ // TODO: templateify the implemenation and support for I64
+ // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+ //case GGML_TYPE_I64:
+ // {
+ // ggml_compute_forward_repeat_i64(params, dst);
+ // } break;
default:
{
GGML_ABORT("fatal error");
//}
}
+static void ggml_compute_forward_set_rows_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ne01;
+
+ assert(ne0 == nc);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+ assert(src0->type == GGML_TYPE_F32);
+ assert(ne02 % ne11 == 0);
+ assert(ne03 % ne12 == 0);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = std::min(ir0 + dr, nr);
+
+ ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+ for (int64_t i03 = 0; i03 < ne03; ++i03) {
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i03%ne12;
+ const int64_t i11 = i02%ne11;
+ const int64_t i10 = i;
+
+ const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+ from_float(
+ (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
+ ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_set_rows(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_set_rows_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
+ }
+ }
+}
+
// ggml_compute_forward_get_rows_back
static void ggml_compute_forward_get_rows_back_f32_f16(
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
uint64_t nb2;
} ggml_metal_kargs_get_rows;
+typedef struct {
+ int32_t nk0;
+ int32_t ne01;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne12;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
typedef struct {
int64_t ne00;
int64_t ne01;
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
const bool use_bfloat = ctx_dev->use_bfloat;
if (!use_bfloat) {
+ if (op->type == GGML_TYPE_BF16) {
+ return false;
+ }
+
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
return false;
{
return op->ne[3] == 1;
}
+ case GGML_OP_SET_ROWS:
+ {
+ if (op->src[0]->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_IQ4_NL:
+ return true;
+ default:
+ return false;
+ };
+ }
default:
return false;
}
};
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
+ case GGML_OP_SET_ROWS:
+ {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ }
+
+ const int32_t nk0 = ne0/ggml_blck_size(dst->type);
+
+ int nth = 32; // SIMD width
+
+ while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+
+ int nrptg = 1;
+ if (nth > nk0) {
+ nrptg = (nth + nk0 - 1)/nk0;
+ nth = nk0;
+
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nrptg--;
+ }
+ }
+
+ nth = MIN(nth, nk0);
+
+ ggml_metal_kargs_set_rows args = {
+ /*.nk0 =*/ nk0,
+ /*.ne01 =*/ ne01,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
+ } break;
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
+static inline int best_index_int8(int n, constant float * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
}
}
+void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ dst.qs[j] = xi0;
+ dst.qs[j] |= xi1 << 4;
+ }
+}
+
+void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; j++) {
+ const float v = src[j];
+ if (min > v) min = v;
+ if (max < v) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+ dst.m = min;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ dst.qs[j] = xi0;
+ dst.qs[j] |= xi1 << 4;
+ }
+}
+
+void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK5_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+ for (int j = 0; j < 4; ++j) {
+ dst.qh[j] = qh8[j];
+ }
+}
+
+void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
+ float max = src[0];
+ float min = src[0];
+
+ for (int j = 1; j < QK5_1; j++) {
+ const float v = src[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+ dst.m = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+ for (int j = 0; j < 4; ++j) {
+ dst.qh[j] = qh8[j];
+ }
+}
+
+void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_NL; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / kvalues_iq4nl_f[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_NL/2 + j]*id;
+
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+ dst.qs[j] = xi0 | (xi1 << 4);
+
+ const float v0 = kvalues_iq4nl_f[xi0];
+ const float v1 = kvalues_iq4nl_f[xi1];
+ const float w0 = src[0 + j]*src[0 + j];
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+
+ }
+
+ dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
}
}
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst.qs[j] = round(x0);
+ }
+}
+
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
#endif
+// TODO: templetify these kernels
kernel void kernel_cpy_f32_q8_0(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = src[j];
- amax = MAX(amax, fabs(v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK8_0].d = d;
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = src[j]*id;
-
- dst_data[i00/QK8_0].qs[j] = round(x0);
- }
+ quantize_q8_0(src, dst_data[i00/QK8_0]);
}
}
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK4_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
- const float d = max / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK4_0].d = d;
-
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_0/2 + j]*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
- dst_data[i00/QK4_0].qs[j] = xi0;
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
- }
+ quantize_q4_0(src, dst_data[i00/QK4_0]);
}
}
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int j = 0; j < QK4_1; j++) {
- const float v = src[j];
- if (min > v) min = v;
- if (max < v) max = v;
- }
-
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK4_1].d = d;
- dst_data[i00/QK4_1].m = min;
-
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK4_1/2 + j] - min)*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
- dst_data[i00/QK4_1].qs[j] = xi0;
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
- }
+ quantize_q4_1(src, dst_data[i00/QK4_1]);
}
}
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK5_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
- const float d = max / -16;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK5_0].d = d;
-
- uint32_t qh = 0;
- for (int j = 0; j < QK5_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK5_0/2 + j]*id;
-
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst_data[i00/QK5_0].qh[j] = qh8[j];
- }
+ quantize_q5_0(src, dst_data[i00/QK5_0]);
}
}
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- float max = src[0];
- float min = src[0];
-
- for (int j = 1; j < QK5_1; j++) {
- const float v = src[j];
- min = v < min ? v : min;
- max = v > max ? v : max;
- }
-
- const float d = (max - min) / 31;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK5_1].d = d;
- dst_data[i00/QK5_1].m = min;
-
- uint32_t qh = 0;
- for (int j = 0; j < QK5_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK5_1/2 + j] - min)*id;
-
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst_data[i00/QK5_1].qh[j] = qh8[j];
- }
- }
-}
-
-static inline int best_index_int8(int n, constant float * val, float x) {
- if (x <= val[0]) return 0;
- if (x >= val[n-1]) return n-1;
- int ml = 0, mu = n-1;
- while (mu-ml > 1) {
- int mav = (ml+mu)/2;
- if (x < val[mav]) mu = mav; else ml = mav;
+ quantize_q5_1(src, dst_data[i00/QK5_1]);
}
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
kernel void kernel_cpy_f32_iq4_nl(
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK4_NL; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
- const float d = max / kvalues_iq4nl_f[0];
- const float id = d ? 1.0f/d : 0.0f;
-
- float sumqx = 0, sumq2 = 0;
- for (int j = 0; j < QK4_NL/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_NL/2 + j]*id;
-
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
- dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
- const float v0 = kvalues_iq4nl_f[xi0];
- const float v1 = kvalues_iq4nl_f[xi1];
- const float w0 = src[0 + j]*src[0 + j];
- const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
- sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
- sumq2 += w0*v0*v0 + w1*v1*v1;
-
- }
-
- dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+ quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
+ constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
- constant ggml_metal_kargs_get_rows & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
template<typename T>
kernel void kernel_get_rows_f(
+ constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
- constant ggml_metal_kargs_get_rows & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
}
kernel void kernel_get_rows_i32(
+ constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device int32_t * dst,
- constant ggml_metal_kargs_get_rows & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
}
}
+template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_set_rows_q32(
+ constant ggml_metal_kargs_set_rows & args,
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+
+ const int32_t i12 = i03%args.ne12;
+ const int32_t i11 = i02%args.ne11;
+
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+ if (i01 >= args.ne01) {
+ return;
+ }
+
+ const int32_t i10 = i01;
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+ device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+ quantize_func(src_row + 32*ind, dst_row[ind]);
+ }
+}
+
+template<typename T>
+kernel void kernel_set_rows_f(
+ constant ggml_metal_kargs_set_rows & args,
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+
+ const int32_t i12 = i03%args.ne12;
+ const int32_t i11 = i02%args.ne11;
+
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+ if (i01 >= args.ne01) {
+ return;
+ }
+
+ const int32_t i10 = i01;
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+ device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+ dst_row[ind] = (T) src_row[ind];
+ }
+}
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+//
+// set rows
+//
+
+typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+
+template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
+template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+
//
// matrix-matrix multiplication
//
"TRANSPOSE",
"GET_ROWS",
"GET_ROWS_BACK",
+ "SET_ROWS",
"DIAG",
"DIAG_MASK_INF",
"DIAG_MASK_ZERO",
"OPT_STEP_ADAMW",
};
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"transpose(x)",
"get_rows(x)",
"get_rows_back(x)",
+ "set_rows(x)",
"diag(x)",
"diag_mask_inf(x)",
"diag_mask_zero(x)",
"adamw(x)",
};
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
tensor->nb[2] == ggml_type_size(tensor->type);
}
+bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+ return
+ tensor->ne[0] == ggml_blck_size(tensor->type) ||
+ tensor->nb[0] == ggml_type_size(tensor->type);
+}
+
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return result;
}
+// ggml_set_rows
+
+struct ggml_tensor * ggml_set_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c) {
+ GGML_ASSERT(a->ne[0] == b->ne[0]);
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
+ GGML_ASSERT(a->ne[3] == b->ne[3]);
+ GGML_ASSERT(b->ne[1] == c->ne[0]);
+ GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+ GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
+ GGML_ASSERT(c->ne[3] == 1);
+ GGML_ASSERT(b->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->type == GGML_TYPE_I64);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(a));
+ GGML_ASSERT(ggml_is_contiguous_rows(b));
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->op = GGML_OP_SET_ROWS;
+ result->src[0] = b;
+ result->src[1] = c;
+
+ return result;
+}
+
// ggml_diag
struct ggml_tensor * ggml_diag(
}
};
+// GGML_OP_SET_ROWS
+struct test_set_rows : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ const std::array<int, 2> nr23; // broadcast only dims 2 and 3
+ const int r; // rows to set
+ const bool v; // view (non-contiguous src1)
+
+ std::string vars() override {
+ return VARS_TO_STR5(type, ne, nr23, r, v);
+ }
+
+ test_set_rows(ggml_type type,
+ std::array<int64_t, 4> ne,
+ std::array<int, 2> nr23,
+ int r, bool v = false)
+ : type(type), ne(ne), nr23(nr23), r(r), v(v) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
+ ggml_set_name(dst, "dst");
+
+ ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r, ne[2]*nr23[0], ne[3]*nr23[1]);
+ ggml_set_name(src, "src");
+
+ ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
+ ggml_set_name(row_idxs, "row_idxs");
+
+ if (v) {
+ src = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
+ row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
+ ggml_set_name(row_idxs, "view_of_rows");
+ }
+
+ ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I64) {
+ if (ggml_is_view_op(t->op)) {
+ continue;
+ }
+
+ for (int i2 = 0; i2 < t->ne[2]; i2++) {
+ for (int i1 = 0; i1 < t->ne[1]; i1++) {
+ // generate a shuffled subset of row indices
+ std::vector<int64_t> data(ne[1]);
+ for (int i = 0; i < ne[1]; i++) {
+ data[i] = i;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ data.resize(t->ne[0]);
+
+ const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+ ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+ }
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+};
+
// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
}
+ test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
+ for (ggml_type type : all_types) {
+ for (int b : {1, 7}) {
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_set_rows(type, { 256, 5, b, 3 }, { 1, 1, }, 1, v));
+ test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
+
+ test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
+
+ if (ggml_blck_size(type) == 1) {
+ test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
+ test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
+ }
+ }
+ }
+ }
+
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {