[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+ continue;
+ }
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
- masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+ if (!KV_bounds_check || j * Bc + c < KV) {
+ masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+ } else {
+ masksh[c][r] = float(0);
+ }
}
}
barrier();
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
- rowmaxf[r] = Sf[r][0];
+ rowmaxf[r] = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+ continue;
+ }
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
}
Moldf[r] = Mf[r];
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+ continue;
+ }
rowsumf[r] += Pf[r][c];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+ continue;
+ }
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
+ f16vec4 K_Tf = f16vec4(0);
+ if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
- uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
- uint ib = coord / BLOCK_SIZE;
- uint iqs = (coord % BLOCK_SIZE);
- f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+ uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
- f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+ K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
+ }
ksh[c * kshstride + d] = K_Tf;
}
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+ if (!KV_bounds_check || j * Bc + c < KV) {
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+ }
}
}
barrier();
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
- float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
+ float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+ continue;
+ }
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+ continue;
+ }
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);