kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
- const int nb = args.ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
+ const int nb = args.ne00/QK4_NL;
+ const int ns01 = args.nb01/args.nb00;
+
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
- float sumf[nr0]={0.f};
+ float sumf[NR0]={0.f};
- device const float * yb = y + ix * QK4_NL + it * 8;
+ device const float * yb = y + ix*QK4_NL + it*8;
uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
float4 qf1, qf2;
- for (int ib = ix; ib < nb; ib += 16) {
+ // [TAG_MUL_MV_WEIRD]
+ for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
- for (short row = 0; row < nr0; row++) {
- device const block_iq4_nl & xb = x[row*nb + ib];
+ for (short row = 0; row < NR0; row++) {
+ device const block_iq4_nl & xb = x[row*ns01 + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
float4 acc1 = {0.f}, acc2 = {0.f};
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
- const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
+ const int nb = args.ne00/QK_K;
+ const int ns01 = args.nb01/args.nb00;
+
const short ix = tiisg/16; // 0 or 1
const short it = tiisg%16; // 0...15
const short ib = it/2;
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
- float sumf[nr0]={0.f};
+ float sumf[NR0]={0.f};
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
float4 qf1, qf2;
- for (int ibl = ix; ibl < nb; ibl += 2) {
+ // [TAG_MUL_MV_WEIRD]
+ for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
- for (short row = 0; row < nr0; ++row) {
- device const block_iq4_xs & xb = x[row*nb + ibl];
+ for (short row = 0; row < NR0; ++row) {
+ device const block_iq4_xs & xb = x[row*ns01 + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
float4 acc1 = {0.f}, acc2 = {0.f};
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
- const int nb = args.ne00/QK_MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
+ const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
+ const int nb = args.ne00/QK_MXFP4;
+ const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
+
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
- float sumf[nr0]={0.f};
+ float sumf[NR0]={0.f};
- device const float * yb = y + ix * QK_MXFP4 + it * 8;
+ device const float * yb = y + ix*QK_MXFP4 + it*8;
+
+ // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
+ // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
+ for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
+ device const float4 * y4 = (device const float4 *) yb;
- for (int ib = ix; ib < nb; ib += 16) {
- device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
-#pragma unroll(nr0)
- for (short row = 0; row < nr0; row++) {
- device const block_mxfp4 & xb = x[row*nb + ib];
+ FOR_UNROLL (short row = 0; row < NR0; row++) {
+ device const block_mxfp4 & xb = x[row*ns01 + ib];
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;