const uint32_t acctype = f32acc ? 4 : 2; // bytes per accumulator element (float vs float16_t)
const uint32_t f16vec4 = 8; // sizeof(f16vec4) in bytes
+ const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); // sized to match the new tmpsh[] scratch the shader uses for the mask max reduction
+
const uint32_t qstride = hsk_pad / 4 + 2; // in units of f16vec4
const uint32_t Qf = Br * qstride * f16vec4;
const uint32_t slope = Br * acctype; // per-row ALiBi slope storage
- const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
+ const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; // rows only need clamping against nem1 when GQA is off and nem1 is not a multiple of Br
+ float max_mask = NEG_FLT_MAX_OVER_2; // per-invocation running max of the mask values
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
+ max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
+ // skip the block if the mask is entirely -inf
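+ // two-stage reduction: subgroupAll within each subgroup, then combine across subgroups through shared memory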
+ bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
+ if (gl_SubgroupInvocationID == 0) {
+ tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; // publish one flag per subgroup: NEG_FLT_MAX_OVER_2 if all of its values were masked out, 0.0 otherwise
+ }
+ barrier();
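+ // combine the per-subgroup flags; any 0.0 flag forces max_mask above the skip threshold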
+ [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+ max_mask = max(max_mask, tmpsh[s]);
+ }
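+ // every invocation computes the same skip decision, so the continue below is workgroup-uniform (no barrier divergence)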
+ if (max_mask <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
}
float Sf[Br][cols_per_thread];
barrier();
}
+ // prevent a race on tmpsh: the mask-skip reduction above still reads it while the reduction below rewrites it
+ barrier();
+
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
return elem;
}
+shared float tmpsh[row_split]; // scratch for combining the per-subgroup results of the mask max reduction
+
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
}
}
}
+ // skip the block if the mask is entirely -inf
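+ // same two-stage skip reduction as the scalar shader path above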
+ bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+ barrier();
+ if (gl_SubgroupInvocationID == 0) {
+ tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+ }
+ barrier();
+ [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+ max_mask = max(max_mask, tmpsh[s]);
+ }
+ if (max_mask <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
}
}
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
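+ // out-of-bounds mask elements clamp to -inf, so they can never rescue a block from being skipped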
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax; // receives the tile-wide max of the mask block
+
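+ // maxReduceFp16 is the combine callback passed to coopMatReduceNV below; presumably an elementwise
+ // fp16 max defined elsewhere in the shader (its definition is not part of this hunk)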
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+ // skip the block if the mask is entirely -inf
+ coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
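+ // RowAndColumn reduces over the entire tile, so every element of mvmax holds the same value and mvmax[0] is the block max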
+ if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+ // skip the block if the mask is entirely -inf
+ coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
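+ // same tile-wide max reduction as in the clamped path above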
+ if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
}
}
}