// Workgroup-shared scratch: one partial minimum per subgroup, written by each
// subgroup's invocation 0 and then combined by local invocation 0.
shared float minsh[NUM_SUBGROUPS];
// Workgroup-shared scratch: one partial maximum per subgroup (paired with minsh).
shared float maxsh[NUM_SUBGROUPS];
+float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); // bit pattern of FLT_MAX / 2; used as the min/max seed and as the "all -inf" threshold
+
+// Classify 16 consecutive Br x Bc tiles of the mask and accumulate a two-bit
+// code per tile into `result`.
+//
+// Tile `block_x` (0..15) covers columns starting at (i0 * 16 + block_x) * Bc and
+// rows starting at i1 * Br; i2 / i3 select the outer slices through the
+// nbm2 / nbm3 strides. For every tile the workgroup computes the minimum and
+// maximum over all loaded elements:
+//   - bit (2 * block_x)     is set when max <= -FLT_MAX_OVER_2 (every loaded
+//                           value is hugely negative, i.e. effectively -inf);
+//   - bit (2 * block_x + 1) is set when min == max == 0.
+// A tile for which no element is loaded keeps its seed values and is therefore
+// reported with the first code.
+//
+// Parameters:
+//   result            - in/out bit field; only local invocation 0 modifies it.
+//   i0, i1, i2, i3    - tile-group coordinates as described above.
+//   need_bounds_check - when false every 4-wide load is assumed to be in range;
+//                       when true, columns >= nem0 are skipped.
+//
+// Contains barrier() calls, so it must be reached by every invocation of the
+// workgroup under uniform control flow.
+void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
+    const uint tid = gl_LocalInvocationIndex;
+
+    [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+        // Seeds chosen so that any real value replaces them.
+        float min_v = FLT_MAX_OVER_2;
+        float max_v = -FLT_MAX_OVER_2;
+        [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+            // Each invocation handles one group of 4 adjacent columns per step.
+            uint j0 = (i + tid) % (Bc / 4);
+            uint j1 = (i + tid) / (Bc / 4);
+
+            j0 *= 4;
+            j0 += (i0 * 16 + block_x) * Bc;
+            j1 += i1 * Br;
+
+            if (!need_bounds_check || j0 + 3 < nem0) {
+                // All four columns are in range: single 4-wide load. The element
+                // offset is divided by 4 to index the 4-wide view data_av4.
+                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    min_v = min(min_v, f[c]);
+                    max_v = max(max_v, f[c]);
+                }
+            } else {
+                // Partially (or fully) out of range: load the in-range columns
+                // one element at a time.
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    if (j0 + c < nem0) {
+                        // Offset must include c so that each component is read,
+                        // matching the guard above.
+                        float f = float(data_a[j0 + c + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                        min_v = min(min_v, f);
+                        max_v = max(max_v, f);
+                    }
+                }
+            }
+        }
+        // Reduce within the subgroup, then publish one partial per subgroup.
+        min_v = subgroupMin(min_v);
+        max_v = subgroupMax(max_v);
+        if (gl_SubgroupInvocationID == 0) {
+            minsh[gl_SubgroupID] = min_v;
+            maxsh[gl_SubgroupID] = max_v;
+        }
+        // Make all subgroup partials visible before invocation 0 combines them.
+        barrier();
+        if (tid == 0) {
+            [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                min_v = min(min_v, minsh[i]);
+                max_v = max(max_v, maxsh[i]);
+            }
+            if (max_v <= -FLT_MAX_OVER_2) {
+                result |= 1 << (2*block_x);
+            }
+            if (min_v == 0.0f && max_v == 0.0f) {
+                result |= 2 << (2*block_x);
+            }
+        }
+        // Keep the next tile's writes to minsh/maxsh from racing with the reads above.
+        barrier();
+    }
+}
+
// For each Br x Bc block of the mask (input) buffer, read all values and check
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
const uint i2 = gl_WorkGroupID.z % nem2;
const uint i3 = gl_WorkGroupID.z / nem2;
- float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
-
uint result = 0;
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
- [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
- float min_v = FLT_MAX_OVER_2;
- float max_v = -FLT_MAX_OVER_2;
- [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
- uint j0 = (i + tid) % (Bc / 4);
- uint j1 = (i + tid) / (Bc / 4);
-
- j0 *= 4;
- j0 += (i0 * 16 + block_x) * Bc;
- j1 += i1 * Br;
-
- vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
- [[unroll]] for (int c = 0; c < 4; ++c) {
- min_v = min(min_v, f[c]);
- max_v = max(max_v, f[c]);
- }
- }
- min_v = subgroupMin(min_v);
- max_v = subgroupMax(max_v);
- if (gl_SubgroupInvocationID == 0) {
- minsh[gl_SubgroupID] = min_v;
- maxsh[gl_SubgroupID] = max_v;
- }
- barrier();
- if (tid == 0) {
- [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
- min_v = min(min_v, minsh[i]);
- max_v = max(max_v, maxsh[i]);
- }
- if (max_v <= -FLT_MAX_OVER_2) {
- result |= 1 << (2*block_x);
- }
- if (min_v == 0.0f && max_v == 0.0f) {
- result |= 2 << (2*block_x);
- }
- }
- barrier();
+ if ((i0 + 1) * 16 * Bc <= nem0) {
+ loadvec4(result, i0, i1, i2, i3, false);
+ } else {
+ loadvec4(result, i0, i1, i2, i3, true);
}
} else {
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {