block_q4_K_packed128 block;
};
+#if defined(IS_MUL_MM2)
+
+// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
+// into shared memory and then process the whole tile using those scales.
+// A fetch function loads the packed scales into private variables, and a store
+// function then decodes them and writes the results to shared memory.
+// Q4_K and Q5_K use the same scale encoding, so everything is shared except the
+// fetch from the block structure (the two block layouts differ).
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
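+// The row stride is padded by 2, which staggers rows across shared memory banks and helps avoid bank conflicts.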
+const uint shAscales_stride = (BM + 2);
+// 1 scale per 32 elements -> 8 scales per block, per row
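+// With the default BM = 64 this is 8 * 66 = 528 vec2 entries (4224 bytes of shared memory).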
+shared vec2 shAscales[8 * shAscales_stride];
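+// Per-invocation staging for the first 128 bits of an A block (d, dmin and the
+// packed 6-bit scales), filled by fetch_scalesQ4_K/Q5_K and consumed by store_scalesQ4_K.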
+uvec4 row_v;
+#endif
+
+#if defined(DATA_A_Q4_K)
+layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
+
+void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
+{
+ uint tids_per_row = BLOCK_SIZE / BM;
+ uint is_per_tid = 8 / tids_per_row;
+ uint is_start = is_per_tid * (tid % tids_per_row);
+ uint tid_row = tid / tids_per_row;
+
+ uint row = ir_BM + tid_row;
+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
+ if (in_bounds || row < p.M) {
+ row_v = data_a_q4_k_packed128[block_index].q4k[0];
+ }
+}
+#endif
+#if defined(DATA_A_Q5_K)
+layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
+
+void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
+{
+ uint tids_per_row = BLOCK_SIZE / BM;
+ uint is_per_tid = 8 / tids_per_row;
+ uint is_start = is_per_tid * (tid % tids_per_row);
+ uint tid_row = tid / tids_per_row;
+
+ uint row = ir_BM + tid_row;
+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
+ if (in_bounds || row < p.M) {
+ row_v = data_a_q5_k_packed128[block_index].q5k[0];
+ }
+}
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+void store_scalesQ4_K(uint tid)
+{
+ barrier();
+
+ uint tids_per_row = BLOCK_SIZE / BM;
+ uint is_per_tid = 8 / tids_per_row;
+ uint is_start = is_per_tid * (tid % tids_per_row);
+ uint tid_row = tid / tids_per_row;
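+ // e.g. BLOCK_SIZE = 256, BM = 64: 4 invocations cooperate per row, each decoding 2 of the 8 sub-block scales.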
+
+ [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
+ uint is = idx + is_start;
+ uvec4 v = row_v;
+ const vec2 loadd = vec2(unpackFloat2x16(v.x));
+
+ uint32_t sc;
+ uint32_t mbyte;
+
+ uint32_t scale0 = v.y;
+ uint32_t scale4 = v.z;
+ uint32_t scale8 = v.w;
+
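+ // Sub-blocks 0..3 store their 6-bit scale/min directly in scale bytes 0..3 / 4..7;
+ // sub-blocks 4..7 take the low/high nibbles of bytes 8..11 plus the top two bits
+ // of bytes 0..3 / 4..7.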
+ uint32_t sc_lo = scale0;
+ uint32_t mb_lo = scale4;
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
+
+ sc = is < 4 ? sc_lo : sc_hi;
+ mbyte = is < 4 ? mb_lo : mb_hi;
+ sc = sc >> (8 * (is & 3));
+ mbyte = mbyte >> (8 * (is & 3));
+ sc &= 0x3F;
+ mbyte &= 0x3F;
+
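+ // Store the per-sub-block scale and min; the dequant functions apply them as value = d * q - m.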
+ const float d = loadd.x * float(sc);
+ const float m = loadd.y * float(mbyte);
+ shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
+ }
+
+ barrier();
+}
+#endif
+
+#endif
+
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
+#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
+ float d = v.x;
+ float m = v.y;
+#else
uvec4 v = bl128.block.q4k[0];
-
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
+#endif
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
+#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
+ float d = v.x;
+ float m = v.y;
+#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
+#endif
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = ((qh >> is) & 0x101) << 4;
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs | qh)[idx & 1];
- float16_t ret = d * (float16_t(qs)) - m;
+ float ret = d * float(qs) - m;
- return ret;
+ return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
+#define fetch_scales fetch_scalesQ4_K
+#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
+#define fetch_scales fetch_scalesQ5_K
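+// The store path is shared with Q4_K (the scale encoding is identical).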
+#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+#define IS_MUL_MM2 1
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 256;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
#define DECODEFUNCA
#endif
+#if !defined(fetch_scales)
+#define fetch_scales(a, b, c, d, e, f)
+#endif
+#if !defined(store_scales)
+#define store_scales(a)
+#endif
+
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
init_iq_shmem(gl_WorkGroupSize);
#endif
+ const uint tid = gl_LocalInvocationIndex;
+
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID)
+
+ const uint START_ALIGN_K = 256;
+ // For the K-quants (block size 256), unroll whole 256-element tiles.
+ // For legacy quants (block size 32) and non-quantized types, unroll 8 BK-sized steps.
+ const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
+ const uint unroll_count = UNROLL_K / BK;
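+ // e.g. a K-quant with BK = 32 (the assumed value for quants): UNROLL_K = 256, unroll_count = 8.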
+
// Detect a fast path where all loads are entirely in bounds and no clamping is required
- if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
(stride_a % 8) == 0 &&
#endif
- (stride_b % 8) == 0 && (start_k % 8) == 0) {
+ (stride_b % 8) == 0) {
// Hint to the compiler that values are aligned (want 16B alignment)
- start_k &= ~7;
+ start_k &= ~(START_ALIGN_K-1);
stride_b &= ~7;
#if QUANT_K == 1
stride_a &= ~7;
tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
- uint k_iters = (end_k - start_k + BK - 1) / BK;
+ uint k_iters = (end_k - start_k) / UNROLL_K;
+ uint block_k = start_k;
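+ // k_iters counts whole UNROLL_K chunks; any remainder is handled in BK-sized steps by the tail loops below.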
+
+ // Fetch the scale values for a tile of quants; these will be copied into shared memory.
+ // The fetches and stores are pipelined to hide the load latency.
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
+
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
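+ // Prefetch the next chunk's scales into registers while the current chunk is multiplied.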
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+ // Manual partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manual partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manual partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+ uint k_iters = (end_k - start_k + BK - 1) / BK;
+
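+ // The scale fetch/store is pipelined here as well, but with in_bounds = false so rows are bounds-checked.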
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+
[[dont_unroll]]
- for (uint block_k = start_k; block_k < end_k; block_k += BK) {
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ store_scales(tid);
+ if (block_k + BK < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
- // Clamping is expensive, so detect different code paths for each combination
- // of A and B needing clamping.
- bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
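+ // Clamp-combination branching removed: A always loads through the clamped layout (and B too, outside MUL_MAT_ID).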
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
- bool unclampedB = true;
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
- bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
- if (unclampedA && unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
-#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
-#else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
-#endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else if (unclampedA && !unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
-
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else if (!unclampedA && unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
-#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
-#else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
-#endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else if (!unclampedA && !unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- }
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
}
// Convert from ACC_TYPE to D_TYPE