};
struct vk_fa_pipeline_state {
- vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
- : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
+ vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
+ : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
uint32_t HSK, HSV;
bool small_rows, small_cache;
FaCodePath path;
bool aligned;
bool f32acc;
+ bool use_mask_opt;
bool operator<(const vk_fa_pipeline_state &b) const {
- return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
- std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
+ return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
+ std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
}
};
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
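+ // Flash-attention mask pre-pass pipelines, keyed by the tile size {Br, Bc}.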
+ std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
+
vk_pipeline pipeline_flash_attn_split_k_reduce;
vk_pipeline pipeline_count_experts;
uint32_t sinks;
};
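+// Push constants for the flash attention mask pre-pass: nem* are the mask
+// dimensions, nbm* are the mask strides in fp16 elements, and nbd* are the
+// output strides in dwords.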
+struct vk_op_flash_attn_mask_opt_push_constants {
+ uint32_t nem0;
+ uint32_t nem1;
+ uint32_t nem2;
+ uint32_t nbm1;
+ uint32_t nbm2;
+ uint32_t nbm3;
+ uint32_t nbd1;
+ uint32_t nbd2;
+ uint32_t nbd3;
+};
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
+ *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
return name.str();
}
if (node->op == GGML_OP_TOP_K) {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
};
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
// AMD prefers loading K directly from global memory
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
- return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
+ return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
};
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
FaCodePath path = fa.first.path; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
+ bool use_mask_opt = fa.first.use_mask_opt; \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} else { \
if (f32acc) { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} \
} \
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+ for (auto &it : device->pipeline_fa_mask_opt) {
+ auto BrBc = it.first;
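+ // Spec constants are {BLOCK_SIZE, NUM_SUBGROUPS, Br, Bc}, matching the constant_id layout of the mask-opt shader.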
+ ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
+ }
+
if (device->subgroup_clustered && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
} else {
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
- const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
-
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
const uint32_t slope = Br * acctype;
- const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
+ const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+ const uint32_t nem0 = mask ? mask->ne[0] : 0;
const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t nem3 = mask ? mask->ne[3] : 0;
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
+ // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
+ bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
+
+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
vk_pipeline pipeline = nullptr;
ggml_vk_preallocate_buffers(ctx, subctx);
}
- {
- // Request descriptor sets
- if (split_k > 1) {
- ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+ auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
+ const uint32_t Br = rows_cols[0];
+ const uint32_t Bc = rows_cols[1];
+
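+ // Each dword of the mask-opt metadata packs 2-bit codes for 16 consecutive Bc-wide
+ // column blocks, with one row of dwords per Br rows of the mask.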
+ const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
+ const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
+
+ vk_pipeline pipeline_fa_mask_opt = nullptr;
+ if (use_mask_opt) {
+ std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
+ auto &pipelines = ctx->device->pipeline_fa_mask_opt;
+ auto it = pipelines.find({Br, Bc});
+ if (it != pipelines.end()) {
+ pipeline_fa_mask_opt = it->second;
+ } else {
+ pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
+ }
+ assert(pipeline_fa_mask_opt);
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
+
+ if (ctx->prealloc_size_y < mask_opt_size) {
+ ctx->prealloc_size_y = mask_opt_size;
+ ggml_vk_preallocate_buffers(ctx, subctx);
+ }
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
}
}
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
+ vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+ if (use_mask_opt)
+ {
+ const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
+ nem0,
+ nem1,
+ nem2,
+ (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
+ (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
+ (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
+ mask_opt_num_dwords,
+ mask_opt_num_dwords * CEIL_DIV(nem1, Br),
+ mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
+ };
+
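+ // One workgroup per output dword: x = dwords per block-row, y = block-rows, z = nem2 * nem3.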
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
+ { mask_buf, mask_opt_buf }, opt_pc,
+ { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
gqa_ratio, split_kv, split_k };
if (split_k > 1) {
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
workgroups_x *= pipeline->wg_denoms[0];
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
- {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
+ {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
// We only use split_k when group query attention is enabled, which means
// there's no more than one tile of rows (i.e. workgroups_x would have been
// one). We reuse workgroups_x to mean the number of splits, so we need to
workgroups_x *= pipeline->wg_denoms[0];
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
- {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
+ {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
pc, { workgroups_x, workgroups_y, workgroups_z });
}
}
return elem;
}
-shared float tmpsh[row_split];
-
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
}
}
+ const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+ // mo_offset will point to the tile starting at row i*Br and col 0
+ uint32_t mo_offset = mo_stride * i;
+
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
}
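+ // Cached mask-opt dword and the group of 16 column blocks it covers (~0 = nothing cached yet).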
+ uint32_t mask_opt = 0;
+ uint32_t mask_opt_idx = ~0;
+
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
+ [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
+ mask_cache[idx] = f16vec4(0);
+ }
+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
- bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
- float max_mask = NEG_FLT_MAX_OVER_2;
- [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
- uint32_t c = (idx + tid) / (Br / 4);
- uint32_t r = (idx + tid) % (Br / 4);
- if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
- if ((!KV_bounds_check || j * Bc + c < KV)) {
- f16vec4 m;
- if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
- data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
- data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
- data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
- max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
- } else if (i * Br + r * 4 + 2 < p.nem1) {
- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
- data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
- data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
- 0.0);
- max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
- } else if (i * Br + r * 4 + 1 < p.nem1) {
- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
- data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
- 0.0,
- 0.0);
- max_mask = max(max(max_mask, float(m[0])), float(m[1]));
- } else if (i * Br + r * 4 < p.nem1) {
- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
- 0.0,
- 0.0,
- 0.0);
- max_mask = max(max_mask, float(m[0]));
- } else {
- m = f16vec4(0.0);
+ if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+ mask_opt_idx = j / 16;
+ mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+ }
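+ // Two-bit code for this Bc-wide block of columns.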
+ uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+ if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+ // skip this block
+ continue;
+ }
+ // Only load if the block is not all zeros
+ if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+ float max_mask = NEG_FLT_MAX_OVER_2;
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) / (Br / 4);
+ uint32_t r = (idx + tid) % (Br / 4);
+ if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+ if ((!KV_bounds_check || j * Bc + c < KV)) {
+ f16vec4 m;
+ if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+ data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+ data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
+ max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
+ } else if (i * Br + r * 4 + 2 < p.nem1) {
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+ data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+ 0.0);
+ max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
+ } else if (i * Br + r * 4 + 1 < p.nem1) {
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+ 0.0,
+ 0.0);
+ max_mask = max(max(max_mask, float(m[0])), float(m[1]));
+ } else if (i * Br + r * 4 < p.nem1) {
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
+ 0.0,
+ 0.0,
+ 0.0);
+ max_mask = max(max_mask, float(m[0]));
+ } else {
+ m = f16vec4(0.0);
+ }
+ mask_cache[idx / WorkGroupSize] = m;
}
- mask_cache[idx / WorkGroupSize] = m;
}
}
}
- // skip the block if the mask is entirely -inf
- bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
- barrier();
- if (gl_SubgroupInvocationID == 0) {
- tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
- }
- barrier();
- [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
- max_mask = max(max_mask, tmpsh[s]);
- }
- if (max_mask <= NEG_FLT_MAX_OVER_2) {
- continue;
- }
}
if (K_LOAD_SHMEM != 0) {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
+ const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+ // mo_offset will point to the tile starting at row i*Br and col 0
+ uint32_t mo_offset = mo_stride * i;
+
uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+ mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
}
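+ // Cached mask-opt dword and the group of 16 column blocks it covers (~0 = nothing cached yet).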
+ uint32_t mask_opt = 0;
+ uint32_t mask_opt_idx = ~0;
+
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
- bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
- if (nem1_bounds_check) {
- tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
- tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
- tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
-
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
-
- coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
- // skip the block if the mask is entirely -inf
- coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
- if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
- continue;
- }
- } else {
- tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
- // Don't clamp against nem1 when GQA is enabled
- uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
- tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
- coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
- // skip the block if the mask is entirely -inf
- coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
- if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
- continue;
+ if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+ mask_opt_idx = j / 16;
+ mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+ }
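+ // Two-bit code for this Bc-wide block of columns.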
+ uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+ if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+ // skip this block
+ continue;
+ }
+ // Only load if the block is not all zeros
+ if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+ if (nem1_bounds_check) {
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+ tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
+
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+ } else {
+ tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+ // Don't clamp against nem1 when GQA is enabled
+ uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
}
}
}
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
+layout (constant_id = 2) const uint Br = 32;
+layout (constant_id = 3) const uint Bc = 32;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float16_t data_a[];};
+layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+layout (push_constant) uniform parameter {
+ uint nem0;
+ uint nem1;
+ uint nem2;
+ uint nbm1;
+ uint nbm2;
+ uint nbm3;
+ uint nbd1;
+ uint nbd2;
+ uint nbd3;
+};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
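+// Each output dword packs 16 two-bit codes: MASK_OPT_ALL_NEG_INF means the
+// flash attention shader can skip the tile entirely, MASK_OPT_ALL_ZERO means it
+// can skip loading the mask, and 0 means neither applies.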
+
+shared float minsh[NUM_SUBGROUPS];
+shared float maxsh[NUM_SUBGROUPS];
+
+// For each Br x Bc block of the mask (input) buffer, read all values and check
+// if it's all -inf or all zero. Write out a two-bit code indicating which it is
+// (or zero for neither). Each workgroup processes 16 tiles and writes out a
+// 32-bit result mask.
+//
+// TODO: This is a lot of work per workgroup; it might make sense to split it
+// across more workgroups in the future.
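+//
+// For example, with the default Br = Bc = 32 each workgroup scans a 32 x 512
+// region of the mask (16 blocks of 32 x 32) and emits a single dword of codes.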
+void main() {
+ // Each workgroup produces one output dword, covering 16 Bc-wide column blocks of one Br-row block
+ const uint tid = gl_LocalInvocationIndex;
+ const uint i0 = gl_WorkGroupID.x;
+ const uint i1 = gl_WorkGroupID.y;
+ const uint i2 = gl_WorkGroupID.z % nem2;
+ const uint i3 = gl_WorkGroupID.z / nem2;
+
+ float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
+
+ uint result = 0;
+
+ // Fast path for fully in-bounds blocks where we can do f16vec4 loads
+ if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
+ ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
+ [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+ float min_v = FLT_MAX_OVER_2;
+ float max_v = -FLT_MAX_OVER_2;
+ [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+ uint j0 = (i + tid) % (Bc / 4);
+ uint j1 = (i + tid) / (Bc / 4);
+
+ j0 *= 4;
+ j0 += (i0 * 16 + block_x) * Bc;
+ j1 += i1 * Br;
+
+ vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+ [[unroll]] for (int c = 0; c < 4; ++c) {
+ min_v = min(min_v, f[c]);
+ max_v = max(max_v, f[c]);
+ }
+ }
+ min_v = subgroupMin(min_v);
+ max_v = subgroupMax(max_v);
+ if (gl_SubgroupInvocationID == 0) {
+ minsh[gl_SubgroupID] = min_v;
+ maxsh[gl_SubgroupID] = max_v;
+ }
+ barrier();
+ if (tid == 0) {
+ [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+ min_v = min(min_v, minsh[i]);
+ max_v = max(max_v, maxsh[i]);
+ }
+ if (max_v <= -FLT_MAX_OVER_2) {
+ result |= 1 << (2*block_x);
+ }
+ if (min_v == 0.0f && max_v == 0.0f) {
+ result |= 2 << (2*block_x);
+ }
+ }
+ barrier();
+ }
+ } else {
+ [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+ float min_v = FLT_MAX_OVER_2;
+ float max_v = -FLT_MAX_OVER_2;
+ [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
+ if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
+ continue;
+ }
+ uint j0 = (i + tid) % Bc;
+ uint j1 = (i + tid) / Bc;
+
+ j0 += (i0 * 16 + block_x) * Bc;
+ j1 += i1 * Br;
+
+ if (j0 < nem0 && j1 < nem1) {
+ float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+ min_v = min(min_v, f);
+ max_v = max(max_v, f);
+ }
+ }
+ min_v = subgroupMin(min_v);
+ max_v = subgroupMax(max_v);
+ if (gl_SubgroupInvocationID == 0) {
+ minsh[gl_SubgroupID] = min_v;
+ maxsh[gl_SubgroupID] = max_v;
+ }
+ barrier();
+ if (tid == 0) {
+ [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+ min_v = min(min_v, minsh[i]);
+ max_v = max(max_v, maxsh[i]);
+ }
+ if (max_v <= -FLT_MAX_OVER_2) {
+ result |= 1 << (2*block_x);
+ }
+ if (min_v == 0.0f && max_v == 0.0f) {
+ result |= 2 << (2*block_x);
+ }
+ }
+ barrier();
+ }
+ }
+
+ if (tid == 0) {
+ data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
+ }
+}