uint32_t nev3;
uint32_t nem1;
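+ // row stride of Q in elements (nb11 and nb21 below are the K and V row strides)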
+ uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
+ uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
+ uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
}
assert(pipelines);
- bool aligned = (KV % pipelines[1]->align) == 0;
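+ // row strides for the shader: byte stride nb?1 divided by the type size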
+ const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
+ const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
+ const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+
+ bool aligned = (KV % pipelines[1]->align) == 0 &&
+ // the "aligned" shader variant forcibly rounds strides down to a multiple of 8 elements, for performance
+ (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
+
vk_pipeline pipeline = pipelines[aligned];
assert(pipeline);
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
- ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
- ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
- ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
+ ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
+ ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
+ ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
Q_uma = d_Q != nullptr;
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
D_uma = d_D != nullptr;
if (mask) {
- ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
+ ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
M_uma = d_M != nullptr;
}
}
}
}
- const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
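+ // the push constants carry the row strides (q_stride/k_stride/v_stride) alongside the nb?2/nb?3 strides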
+ const vk_flash_attn_push_constants pc = { N, KV,
+ (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
+ (uint32_t)neq2, (uint32_t)neq3,
+ (uint32_t)nek2, (uint32_t)nek3,
+ (uint32_t)nev2, (uint32_t)nev3,
+ nem1,
+ q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
+ k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
+ v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
+ nbm1,
+ scale, max_bias, logit_softcap,
+ mask != nullptr, n_head_log2, m0, m1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
ggml_tensor * src2 = tensor->src[2];
+ ggml_tensor * src3 = tensor->src[3];
void * tensor_data = tensor->data;
if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
}
+ if (src3 != nullptr) {
+ std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+ }
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
}
+ if (src3 != nullptr) {
+ std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+ }
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
}
+ if (src3 != nullptr) {
+ std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+ }
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
uint32_t nev3;
uint32_t nem1;
+ uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
+ uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
+ uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+ // nb?1 are already divided by the type size and are in units of elements
+ uint32_t q_stride = p.nb01;
+ uint32_t k_stride = p.nb11;
+ uint32_t v_stride = p.nb21;
+ // hint to the compiler that strides are aligned for the aligned variant of the shader
+ if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
+ {
+ q_stride &= ~7;
+#if !defined(BLOCK_SIZE)
+ k_stride &= ~7;
+ v_stride &= ~7;
+#endif
+ }
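+ // apply the row strides to the Q/K/V tensor layouts; the innermost stride is 1 element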
+ tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
+ tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
+ tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
+
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
const float logit_softcap; // Gemma 2
const ggml_type type_KV;
+ std::array<int32_t, 4> permute;
std::string vars() override {
- return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
+ return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
}
double max_nmse_err() override {
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
- bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
- : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
+ bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
+ std::array<int32_t, 4> permute = {0, 1, 2, 3})
+ : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
- ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
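+ // allocate the tensor with permuted dimensions and permute the view back, so the logical shape stays (ne0, ne1, ne2, ne3) while a non-identity permute leaves the memory layout non-contiguous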
+ auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
+ int64_t ne[4] = {ne0, ne1, ne2, ne3};
+ int64_t ne_perm[4];
+ for (int i = 0; i < 4; ++i) {
+ ne_perm[permute[i]] = ne[i];
+ }
+ ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
+ if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
+ t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
+ }
+ return t;
+ };
+
+ ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
ggml_set_name(q, "q");
- ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+ ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(k, "k");
- ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+ ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
+ // run the permuted variant only for a subset of the test cases
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
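+ // {0, 2, 1, 3} swaps dims 1 and 2 in memory, so the Q/K/V rows are no longer contiguous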
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+ }
}
}
}