device const void * src0,
device const float * src1,
device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- int64_t ne10,
- int64_t ne12,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values,
uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
#define N_F32_F32 4
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
#if QK_K == 256
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
#if QK_K == 256
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q5_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q6_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq1_m_f32")]]
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
// matrix-vector multiplication
//
-[[host_name("kernel_mul_mv_id_f32_f32")]]
-kernel void kernel_mul_mv_id_f32_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_f32_f32_impl(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
-
-[[host_name("kernel_mul_mv_id_f16_f32")]]
-kernel void kernel_mul_mv_id_f16_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
+typedef void (kernel_mul_mv_impl_t)(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]);
- kernel_mul_mv_f16_f32_impl(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
+typedef void (kernel_mul_mv2_impl_t)(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
-[[host_name("kernel_mul_mv_id_q8_0_f32")]]
-kernel void kernel_mul_mv_id_q8_0_f32(
- device const char * src0s,
+template<kernel_mul_mv_impl_t impl_fn>
+void mmv_fn(
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
- constant int & idx,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q8_0_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
}
-[[host_name("kernel_mul_mv_id_q4_0_f32")]]
-kernel void kernel_mul_mv_id_q4_0_f32(
- device const char * src0s,
+template<kernel_mul_mv2_impl_t impl_fn>
+void mmv_fn(
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
- constant int & idx,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}
-[[host_name("kernel_mul_mv_id_q4_1_f32")]]
-kernel void kernel_mul_mv_id_q4_1_f32(
- device const char * src0s,
+typedef void (mul_mv_impl_fn_t)(
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
- constant int & idx,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
-[[host_name("kernel_mul_mv_id_q5_0_f32")]]
-kernel void kernel_mul_mv_id_q5_0_f32(
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
device const char * src0s,
device const char * src1,
device float * dst,
constant uint & r2,
constant uint & r3,
constant int & idx,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
device const char * src0 = src0s + id*nb02;
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ impl_fn(
src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
+ src1 + bid*nb11,
+ dst + bid*ne0,
ne00,
ne01,
ne02,
+ nb00,
+ nb01,
+ nb02,
ne10,
+ ne11,
ne12,
+ ne13,
+ nb10,
+ nb11,
+ nb12,
ne0,
ne1,
+ nb1,
r2,
r3,
+ shared_values,
tgpig,
+ tiitg,
tiisg,
sgitg);
}
-[[host_name("kernel_mul_mv_id_q5_1_f32")]]
-kernel void kernel_mul_mv_id_q5_1_f32(
+typedef void (kernel_mul_mv_id_t)(
device const char * src0s,
device const char * src1,
device float * dst,
constant uint & r2,
constant uint & r3,
constant int & idx,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
+#if QK_K != 64
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+#endif
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q2_K_f32")]]
-kernel void kernel_mul_mv_id_q2_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q2_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q3_K_f32")]]
-kernel void kernel_mul_mv_id_q3_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q3_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_K_f32")]]
-kernel void kernel_mul_mv_id_q4_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q4_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_K_f32")]]
-kernel void kernel_mul_mv_id_q5_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q5_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q6_K_f32")]]
-kernel void kernel_mul_mv_id_q6_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q6_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xxs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_xxs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_xs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq3_xxs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq3_xxs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
-kernel void kernel_mul_mv_id_iq3_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq3_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
-kernel void kernel_mul_mv_id_iq2_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
-kernel void kernel_mul_mv_id_iq1_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq1_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
-kernel void kernel_mul_mv_id_iq1_m_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq1_m_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
-kernel void kernel_mul_mv_id_iq4_nl_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq4_nl_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
-kernel void kernel_mul_mv_id_iq4_xs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
-#if QK_K == 64
- kernel_mul_mv_iq4_nl_f32_impl(
-#else
- kernel_mul_mv_iq4_xs_f32_impl(
-#endif
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}