// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
-static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
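+// Large head sizes get fewer rows per workgroup: the scalar FA shader's
+// per-row state grows with HSV, so Br must shrink to stay within shared
+// memory limits.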
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
+    if (hsv >= 512) {
+        return 2;
+    } else {
+        return 8;
+    }
+}
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
// 128 threads split into four subgroups, each subgroup does 1/4
    if (small_rows) {
        return {scalar_flash_attention_num_small_rows, 64};
    } else {
-        return {scalar_flash_attention_num_large_rows, 32};
+        return {get_fa_scalar_num_large_rows(hsv), 32};
    }
}
    // small cols to reduce register count
    if (ggml_is_quantized(type) || hsk >= 256) {
-        return {64, 32};
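+        // Very large K head sizes also shrink the row tile to stay within the register budget.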
+        if (hsk >= 512) {
+            return {32, 32};
+        } else {
+            return {64, 32};
+        }
    }
    return {64, 64};
}
    const uint32_t warps = warptile[0] / warptile[10];
    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
+    const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
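+    // The extra 4 bytes account for the shared counter (_ne1_sh) the shader
+    // uses to broadcast the matched-row count out of the first subgroup.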
    const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
        s_mmq_wg_denoms_k = { 32, 32, 1 };
        // spec constants and tile sizes for quant matmul_id
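+        // The large variant uses a 128x128 output tile per workgroup; the
+        // warptile N entry and l_mmqid_wg_denoms must stay in sync.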
-        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_warptile_mmqid = { 256, 128, 128, 16, 0 };
        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
-        l_mmqid_wg_denoms = { 128, 64, 1 };
+        l_mmqid_wg_denoms = { 128, 128, 1 };
        m_mmqid_wg_denoms = { 128, 64, 1 };
        s_mmqid_wg_denoms = { 128, 64, 1 };
    // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = scalar_flash_attention_num_large_rows;
+    const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
    const uint32_t Bc = scalar_flash_attention_Bc;
    const uint32_t tmpsh = wg_size * sizeof(float);
    case FA_SCALAR:
    case FA_COOPMAT1:
        // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = scalar_flash_attention_num_large_rows;
+        max_gqa = get_fa_scalar_num_large_rows(HSV);
        break;
    case FA_COOPMAT2:
        max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
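+// _ne1 counts the rows matching this expert; with COOPMAT it is computed by
+// the first subgroup and broadcast to the workgroup through _ne1_sh.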
+uint _ne1;
+#ifdef COOPMAT
+shared uint _ne1_sh;
+#endif
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
-    uint _ne1 = 0;
+#ifdef COOPMAT
+    // Spread the search across all elements in the first subgroup
+    if (gl_SubgroupID == 0) {
+        _ne1 = 0;
+        uint num_elements = p.nei1 * p.nei0;
+
+        uint ids[16];
+        uint iter = 0;
+
+        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
+            // prefetch up to 16 elements
+            if (iter == 0) {
+                [[unroll]] for (uint k = 0; k < 16; ++k) {
+                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
+                    bool in_range = i < num_elements;
+                    uint ii1 = i / p.nei0;
+                    uint ii0 = i % p.nei0;
+                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+                }
+            }
+            uint i = j + gl_SubgroupInvocationID;
+            bool in_range = i < num_elements;
+            uint ii1 = i / p.nei0;
+            uint ii0 = i % p.nei0;
+            uint id = ids[iter++];
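+            // Compact the matches with a ballot: each matching lane writes at an
+            // offset equal to the number of lower matching lanes, keeping row_ids dense.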
+            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+            uint idx = subgroupBallotExclusiveBitCount(ballot);
+            if (in_range && id == expert_idx) {
+                row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+            }
+            _ne1 += subgroupBallotBitCount(ballot);
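+            // wrap after 16 iterations so the next loop trip refills ids[]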
+            iter &= 15;
+        }
+        _ne1_sh = _ne1;
+    }
+
+    barrier();
+
+    _ne1 = _ne1_sh;
+#else
+    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                row_ids[_ne1++] = u16vec2(ii0, ii1);
            }
        }
    }
    barrier();
+#endif
    // Workgroup has no work
    if (ic * BN >= _ne1) return;
        _ne1 = 0;
        uint num_elements = p.nei1 * p.nei0;
-        for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
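+        // Buffer up to 16 ids per lane in registers so data_ids is read in
+        // batches instead of once per loop iteration.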
+        uint ids[16];
+        uint iter = 0;
+
+        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
+            // prefetch up to 16 elements
+            if (iter == 0) {
+                [[unroll]] for (uint k = 0; k < 16; ++k) {
+                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
+                    bool in_range = i < num_elements;
+                    uint ii1 = i / p.nei0;
+                    uint ii0 = i % p.nei0;
+                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+                }
+            }
+            uint i = j + gl_SubgroupInvocationID;
            bool in_range = i < num_elements;
-            uint ii0 = i % p.nei0;
            uint ii1 = i / p.nei0;
-            uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+            uint ii0 = i % p.nei0;
+            uint id = ids[iter++];
            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
            uint idx = subgroupBallotExclusiveBitCount(ballot);
            if (in_range && id == expert_idx) {
                row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
            }
            _ne1 += subgroupBallotBitCount(ballot);
+            iter &= 15;
        }
        _ne1_sh = _ne1;
    }
            fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
        }
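+        // A fully in-bounds tile can use the unclamped tensor layout for A,
+        // letting the cooperative-matrix load skip out-of-bounds handling.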
-        coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-        coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+        if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
-        coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
-        coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
-        coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
-        sum = coopMatMulAdd(mat_a, mat_b, sum);
+            sum = coopMatMulAdd(mat_a, mat_b, sum);
+        } else {
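+            // Partial tile: keep the clamped layouts so edge tiles don't read out of bounds.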
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+#endif
+
+            sum = coopMatMulAdd(mat_a, mat_b, sum);
+        }
    }
    // Convert from ACC_TYPE to D_TYPE