memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
- // TODO: handle transposed/permuted matrices
-
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
-
- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const int64_t i11 = i01;
+ const int64_t i12 = i02%ne12;
+ const int64_t i13 = i03%ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+ ggml_vec_cpy_f32 (ne00, wp, sp);
+ ggml_vec_scale_f32(ne00, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
}
- }
- }
#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
#endif
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
+ float max = -INFINITY;
+ ggml_vec_max_f32(ne00, &max, wp);
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+ assert(sum > 0.0);
- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(ne00, dp, sum);
#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
#endif
+ }
+ }
}
}
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
memset(VKQ32, 0, DV*sizeof(float));
}
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
+ uint3 tptg[[threads_per_threadgroup]]) {
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+ const int32_t i01 = tgpig.x;
+
+ const int32_t i13 = i03%args.ne13;
+ const int32_t i12 = i02%args.ne12;
+ const int32_t i11 = i01;
- device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
- device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+ device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
- const int64_t h = i02;
+ const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
// parallel max
float lmax = -INFINITY;
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
// parallel sum
float lsum = 0.0f;
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
float sum = simd_sum(lsum);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
const float inv_sum = 1.0f/sum;
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum;
}
}
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
+ uint3 tptg[[threads_per_threadgroup]]) {
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+ const int32_t i01 = tgpig.x;
+
+ const int32_t i13 = i03%args.ne13;
+ const int32_t i12 = i02%args.ne12;
+ const int32_t i11 = i01;
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
- device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+ device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
if (args.max_bias > 0.0f) {
- const int64_t h = i02;
+ const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
// parallel max
float4 lmax4 = -INFINITY;
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
// parallel sum
float4 lsum4 = 0.0f;
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
float sum = simd_sum(lsum);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
const float inv_sum = 1.0f/sum;
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum;
}
}
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
const float m = pm[ic + tiisg];
const bool has_mask = mask != q;
// pointer to the mask
- device const half * pm = (device const half *) (mask + iq1*args.nb31);
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
float slope = 1.0f;
const std::array<int64_t, 4> ne;
const bool mask;
const ggml_type m_prec;
+ const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
const float scale;
const float max_bias;
std::string vars() override {
- return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
+ return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
}
// the 1024 test with bias occasionally fails:
std::array<int64_t, 4> ne = {10, 5, 4, 3},
bool mask = false,
ggml_type m_prec = GGML_TYPE_F32,
+ std::array<int64_t, 2> nr23 = {1, 1},
float scale = 1.0f,
float max_bias = 0.0f)
- : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
+ : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * mask = nullptr;
if (this->mask) {
- mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
+ mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(mask, "mask");
}
const int64_t hsk; // K head size
const int64_t hsv; // V head size
const int64_t nh; // num heads
- const int64_t nr; // repeat in Q, tests for grouped-query attention
+ const std::array<int64_t, 2> nr23; // repeat in dim 2 and 3, tests for grouped-query attention
const int64_t kv; // kv size
const int64_t nb; // batch size
std::array<int32_t, 4> permute;
std::string vars() override {
- return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+ return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
}
double max_nmse_err() override {
GGML_UNUSED(t);
// Just counting matmul costs:
// Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
- return 2 * nh*nr * nb * (hsk + hsv) * kv;
+ return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1];
}
- test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+ test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
- : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+ : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
return t;
};
- ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
+ ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
ggml_set_name(q, "q");
- ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
+ ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
ggml_set_name(k, "k");
- ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
+ ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
if (mask) {
- m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+ m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
ggml_set_name(m, "m");
}
for (int64_t ne1 : {16, 1024}) {
if (mask) {
for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+
+ if (ne0 <= 32 && ne1 <= 32) {
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+ }
}
} else {
/* The precision of mask here doesn't matter as boolean mask is false */
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
}
}
}
}
}
}
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
for (float logit_softcap : {0.0f, 10.0f}) {
if (hsk != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 4, }) {
- for (int nr : { 1, 4, 16 }) {
- if (nr == 16 && hsk != 128) continue;
- for (int kv : { 512, 1024, }) {
- if (nr != 1 && kv != 512) continue;
- for (int nb : { 1, 3, 32, 35, }) {
- for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
- if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
- // run fewer test cases permuted
- if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ for (int nr3 : { 1, 3, }) {
+ if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+ for (int nr2 : { 1, 4, 16 }) {
+ if (nr2 == 16 && hsk != 128) continue;
+ for (int kv : { 512, 1024, }) {
+ if (nr2 != 1 && kv != 512) continue;
+ for (int nb : { 1, 3, 32, 35, }) {
+ for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+ if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+ // run fewer test cases permuted
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ test_cases.emplace_back(new test_flash_attn_ext(
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ }
}
}
}
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+ test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
}
}
}