// Workgroup size: x is taken from specialization constant id 0; y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Binding 0: read-only buffer A, exposed as an array of A_TYPE.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
// Additional read-only declarations on the same binding 0. Each one exposes the
// data as an array of the corresponding packed type and is only declared when
// that type macro (A_TYPE_PACKED16 / A_TYPE_PACKED32) is defined.
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
// Binding 1: read-only array of B_TYPE. Binding 2: write-only array of D_TYPE.
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif
#elif defined(DATA_A_Q4_0)
// Q4_0: expand packed 4-bit values into buf_a as (q - 8) * d.
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
// Removed variant: one qs entry per idx (16 idx values per block), producing two
// outputs: the low nibble at buf_idx and the high nibble at buf_idx + 16.
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
// Added variant: eight outputs per idx, so the destination offset advances by
// 4 * loadr_a and each block index ib covers four consecutive idx values.
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+ const uint ib = idx / 4;
+ const uint iqs = idx & 0x03;
+
+ const float d = float(data_a_packed16[ib].d);
// Join qs[2*iqs] (low half) and qs[2*iqs + 1] (shifted left by 16) into one 32-bit word.
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
// v0 takes the low nibble of each byte of vui, v1 the high nibble of each byte.
// NOTE(review): unpack8 is not defined in this view; it is presumably the
// explicit-arithmetic-types built-in that splits a 32-bit word into four 8-bit lanes.
+ const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
+ const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
+
// Low-nibble results fill buf_idx .. buf_idx + 3; high-nibble results are stored
// 16 entries further on, at buf_idx + 16 .. buf_idx + 19.
+ buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+ buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+ buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+ buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+ buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q4_1)
// Q4_1: expand packed 4-bit values into buf_a as q * d + m.
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
// Removed variant: one qs entry per idx (16 idx values per block), producing two
// outputs: the low nibble at buf_idx and the high nibble at buf_idx + 16.
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
// Added variant: eight outputs per idx, so the destination offset advances by
// 4 * loadr_a and each block index ib covers four consecutive idx values.
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+ const uint ib = idx / 4;
+ const uint iqs = idx & 0x03;
+
// d is the per-block multiplier and m the per-block additive term.
+ const float d = float(data_a_packed16[ib].d);
+ const float m = float(data_a_packed16[ib].m);
// Join qs[2*iqs] (low half) and qs[2*iqs + 1] (shifted left by 16) into one 32-bit word.
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
// v0 takes the low nibble of each byte of vui, v1 the high nibble of each byte.
// NOTE(review): unpack8 is not defined in this view; it is presumably the
// explicit-arithmetic-types built-in that splits a 32-bit word into four 8-bit lanes.
+ const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
+ const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+
// Low-nibble results fill buf_idx .. buf_idx + 3; high-nibble results are stored
// 16 entries further on, at buf_idx + 16 .. buf_idx + 19.
+ buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+ buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+ buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+ buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+ buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q5_0)
// Q5_0: rebuild 5-bit values from a 4-bit field in qs plus one extra bit from qh,
// then store (q - 16) * d into buf_a.
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
// The added variant writes four outputs per idx (two at buf_idx, two at buf_idx + 16),
// so the destination offset advances by 2 * loadr_a instead of loadr_a.
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
// Each block index ib now covers 8 consecutive idx values instead of 16.
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+ const float d = float(data_a_packed16[ib].d);
// Combine qh[1] (upper 16 bits) and qh[0] (lower 16 bits) into one 32-bit mask.
+ const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
// Move the selected qh bit into bit 4 (0x10) for each of the four values:
//   qh0.x <- bit 2*iqs       qh0.y <- bit 2*iqs + 16
//   qh1.x <- bit 2*iqs + 1   qh1.y <- bit 2*iqs + 17
// `>> 2*iqs` parses as `>> (2*iqs)` because * binds tighter than >>.
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
// One qs entry supplies four 4-bit fields: bits 0-3, 4-7, 8-11 and 12 upward.
// NOTE(review): `vui >> 12` is not masked, so it relies on the qs entry having no
// bits above bit 15 - confirm against the A_TYPE_PACKED16 definition.
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+ const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
// Store order: v.x and v.z (fields at bits 0-3 and 8-11) go to buf_idx and buf_idx + 1;
// v.y and v.w (fields at bits 4-7 and 12+) go to buf_idx + 16 and buf_idx + 17.
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q5_1)
// Q5_1: rebuild 5-bit values from a 4-bit field in qs plus one extra bit from qh,
// then store q * d + m into buf_a.
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
// The added variant writes four outputs per idx (two at buf_idx, two at buf_idx + 16),
// so the destination offset advances by 2 * loadr_a instead of loadr_a.
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
// Each block index ib now covers 8 consecutive idx values instead of 16.
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint uint_qh = data_a[ib].qh;
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
// d is the per-block multiplier and m the per-block additive term.
+ const float d = float(data_a_packed16[ib].d);
+ const float m = float(data_a_packed16[ib].m);
// Here qh is read as a single value (no qh[0]/qh[1] combination is needed).
+ const uint uint_qh = data_a_packed16[ib].qh;
// Move the selected qh bit into bit 4 (0x10) for each of the four values:
//   qh0.x <- bit 2*iqs       qh0.y <- bit 2*iqs + 16
//   qh1.x <- bit 2*iqs + 1   qh1.y <- bit 2*iqs + 17
// `>> 2*iqs` parses as `>> (2*iqs)` because * binds tighter than >>.
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
// One qs entry supplies four 4-bit fields: bits 0-3, 4-7, 8-11 and 12 upward.
// NOTE(review): `vui >> 12` is not masked, so it relies on the qs entry having no
// bits above bit 15 - confirm against the A_TYPE_PACKED16 definition.
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+ const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
// Store order: v.x and v.z (fields at bits 0-3 and 8-11) go to buf_idx and buf_idx + 1;
// v.y and v.w (fields at bits 4-7 and 12+) go to buf_idx + 16 and buf_idx + 17.
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q8_0)
// Q8_0: scale signed 8-bit values by d and store them contiguously in buf_a.
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
// The destination offset advances by LOAD_VEC_A entries per loadr_a step.
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
// Removed variant: 16 idx values per block, two qs entries (two outputs) per idx.
// Added variant: 8 idx values per block, two packed qs entries (four outputs) per idx.
- const uint ib = idx / 16;
- const uint iqs = (idx & 0xF) * 2;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+ const float d = float(data_a_packed16[ib].d);
// Each packed qs entry is split into a pair of signed 8-bit values.
// NOTE(review): unpack8 is not defined in this view; presumably the
// explicit-arithmetic-types built-in for 16-bit -> i8vec2.
+ const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
+ const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
+ const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
// The four scaled values are written to consecutive entries buf_idx .. buf_idx + 3.
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_NL)
// IQ4_NL: each 4-bit field is an index into the kvalues_iq4nl table; the looked-up
// value is multiplied by d and stored into buf_a.
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
// The added variant writes four outputs per idx (two at buf_idx, two at buf_idx + 16),
// so the destination offset advances by 2 * loadr_a instead of loadr_a.
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
// Each block index ib now covers 8 consecutive idx values instead of 16.
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
+ const uint ib = idx / 8;
+ const uint iqs = idx & 0x07;
- const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
// The added variant keeps d in FLOAT_TYPE, so the multiply happens in FLOAT_TYPE
// (the removed variant multiplied in float and converted afterwards).
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
// Field-to-slot mapping for the added variant:
//   bits 0-3  -> buf_idx        bits 8-11 -> buf_idx + 1
//   bits 4-7  -> buf_idx + 16   bits 12+  -> buf_idx + 17
// NOTE(review): `vui >> 12` is not masked, so it relies on the qs entry having no
// bits above bit 15 - confirm against the A_TYPE_PACKED16 definition.
+ buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
+ buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
+ buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {