#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
- const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
- const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
- const float d = float(data_a_packed16[ib].d);
- const float m = float(data_a_packed16[ib].m);
- const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
- const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
- const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+ const vec2 dm = vec2(data_a_packed32[ib].dm);
+ const uint vui = data_a_packed32[ib].qs[iqs];
+ const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
+ const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
- const uint buf_idx = col * SHMEM_STRIDE + row;
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
- const uint buf_idx = col * SHMEM_STRIDE + row;
-
- const uint ib = idx / 8;
- const uint iqs = idx & 0x07;
-
- const float d = float(data_a_packed16[ib].d);
- const float m = float(data_a_packed16[ib].m);
- const uint uint_qh = data_a_packed16[ib].qh;
- const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
- const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
- const uint vui = uint(data_a_packed16[ib].qs[iqs]);
- const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
+ const uint ib = idx / 4;
+ const uint iqs = idx & 0x03;
- buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
- buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
+ const vec2 dm = vec2(data_a_packed32[ib].dm);
+ const uint uint_qh = data_a_packed32[ib].qh;
+ const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
+ const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
+ const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
+ const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
+
+ const uint vui = data_a_packed32[ib].qs[iqs];
+ const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
+ const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
+
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
+ buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
+ const uint ib = idx / 64; // 4 values per idx
+ const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0,2,4..30
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
- const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
+ const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
const uint scales = data_a[ib].scales[scalesi];
const vec2 dm = vec2(data_a[ib].dm);
- const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
+ const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
- buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
+ const uint ib = idx / 64; // 4 values per idx
+ const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
- const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
+ const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
- buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
- fma(d, q.y, m));
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
+ const uint ib = idx / 64; // 4 values per idx
+ const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
- const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
- const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
- const vec2 q = vec2(unpack8(qs | qh).xy);
+ const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
+ const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
+ const vec4 q = vec4(unpack8(qs | qh));
- buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
- fma(d, q.y, m));
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
- const uint buf_idx = col * SHMEM_STRIDE + row;
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
- const uint buf_idx = col * SHMEM_STRIDE + row;
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;