VERSION ${GGML_INSTALL_VERSION}
COMPATIBILITY SameMajorVersion)
-target_compile_definitions(ggml-base PRIVATE
- GGML_VERSION="${GGML_INSTALL_VERSION}"
- GGML_COMMIT="${GGML_BUILD_COMMIT}"
-)
-message(STATUS "ggml version: ${GGML_INSTALL_VERSION}")
-message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}")
-
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
// misc
- GGML_API const char * ggml_version(void);
- GGML_API const char * ggml_commit(void);
-
GGML_API void ggml_time_init(void); // call this once at the beginning of the program
GGML_API int64_t ggml_time_ms(void);
GGML_API int64_t ggml_time_us(void);
struct ggml_context * ctx,
struct ggml_tensor * a);
+ // a [ne0, ne01, ne02, ne03]
+ // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
+ //
+ // broadcast:
+ // ne02 % ne12 == 0
+ // ne03 % ne13 == 0
+ //
// fused soft_max(a*scale + mask*(ALiBi slope))
- // mask is optional
// max_bias = 0.0f for no ALiBi
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
#define GGML_KQ_MASK_PAD 64
- // q: [n_embd_k, n_batch, n_head, 1]
- // k: [n_embd_k, n_kv, n_head_kv, 1]
- // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
- // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
- // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
+ // q: [n_embd_k, n_batch, n_head, ne3]
+ // k: [n_embd_k, n_kv, n_head_kv, ne3]
+ // v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+ // res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
+ //
+ // broadcast:
+ // n_head % n_head_kv == 0
+ // ne3 % ne32 == 0
+ //
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
case GGML_OP_SQRT:
case GGML_OP_CLAMP:
case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
+ case GGML_OP_SOFT_MAX:
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
case GGML_OP_FLASH_ATTN_EXT:{
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
// DeepSeek MLA
return false;
}
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
- // TODO: handle transposed/permuted matrices
-
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
-
- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const int64_t i11 = i01;
+ const int64_t i12 = i02%ne12;
+ const int64_t i13 = i03%ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+ ggml_vec_cpy_f32 (ne00, wp, sp);
+ ggml_vec_scale_f32(ne00, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
}
- }
- }
#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
#endif
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
+ float max = -INFINITY;
+ ggml_vec_max_f32(ne00, &max, wp);
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+ assert(sum > 0.0);
- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(ne00, dp, sum);
#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
#endif
+ }
+ }
}
}
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
memset(VKQ32, 0, DV*sizeof(float));
}
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
return true;
+ case GGML_OP_SOFT_MAX:
+ // TODO: support batching
+ if (op->src[0]->ne[3] != 1) {
+ return false;
+ }
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
case GGML_OP_SOFT_MAX_BACK: {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
if (op->src[0]->ne[0] == 192) {
return false;
}
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
+ int32_t ne32;
uint64_t nb31;
+ uint64_t nb32;
int32_t ne1;
int32_t ne2;
float scale;
} ggml_metal_kargs_sum_rows;
typedef struct {
- int64_t ne00;
- int64_t ne01;
- int64_t ne02;
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
float scale;
float max_bias;
float m0;
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
- return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+ return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
- const int64_t nrows_x = ggml_nrows(src0);
- const int64_t nrows_y = src0->ne[1];
-
- const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_DIAG_MASK_INF:
{
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
+ /*.ne32 =*/ ne32,
/*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
+ uint3 tptg[[threads_per_threadgroup]]) {
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+ const int32_t i01 = tgpig.x;
+
+ const int32_t i13 = i03%args.ne13;
+ const int32_t i12 = i02%args.ne12;
+ const int32_t i11 = i01;
- device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
- device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+ device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
- const int64_t h = i02;
+ const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
// parallel max
float lmax = -INFINITY;
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
// parallel sum
float lsum = 0.0f;
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
float sum = simd_sum(lsum);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
const float inv_sum = 1.0f/sum;
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum;
}
}
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
+ uint3 tptg[[threads_per_threadgroup]]) {
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+ const int32_t i01 = tgpig.x;
+
+ const int32_t i13 = i03%args.ne13;
+ const int32_t i12 = i02%args.ne12;
+ const int32_t i11 = i01;
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
- device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+ device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
if (args.max_bias > 0.0f) {
- const int64_t h = i02;
+ const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
// parallel max
float4 lmax4 = -INFINITY;
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
// parallel sum
float4 lsum4 = 0.0f;
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
float sum = simd_sum(lsum);
- if (ntg > N_SIMDWIDTH) {
+ if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
const float inv_sum = 1.0f/sum;
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum;
}
}
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
const float m = pm[ic + tiisg];
const bool has_mask = mask != q;
// pointer to the mask
- device const half * pm = (device const half *) (mask + iq1*args.nb31);
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
float slope = 1.0f;
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
- case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
- return true;
+ // TODO: support batching
+ if (op->src[0]->ne[3] != 1) {
+ return false;
+ }
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+ case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
return true;
vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
- vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
+ vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
- vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
};
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
-static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
- GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
- ne = ne != 0 ? ne : ggml_nelements(dst);
- GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
-
- vk_op_unary_push_constants p{};
- p.ne = (uint32_t)ne;
-
- size_t src0_tsize = ggml_type_size(src0->type);
- p.ne00 = (uint32_t)src0->ne[0];
- p.ne01 = (uint32_t)src0->ne[1];
- p.ne02 = (uint32_t)src0->ne[2];
- p.ne03 = (uint32_t)src0->ne[3];
- p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
- p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
- p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
- p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
-
- size_t dst_tsize = ggml_type_size(dst->type);
- p.ne10 = (uint32_t)dst->ne[0];
- p.ne11 = (uint32_t)dst->ne[1];
- p.ne12 = (uint32_t)dst->ne[2];
- p.ne13 = (uint32_t)dst->ne[3];
- p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
- p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
- p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
- p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
-
- return p; // fastdiv values and offsets are initialized later in ggml_vk_op
-}
-
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
- uint32_t ne00; uint32_t ne01;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
}
return nullptr;
case GGML_OP_UPSCALE:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- int mode = ggml_get_op_params_i32(dst, 0);
- switch (mode) {
- case GGML_SCALE_MODE_NEAREST:
- return ctx->device->pipeline_upscale_nearest_f32;
- case GGML_SCALE_MODE_BILINEAR:
- return ctx->device->pipeline_upscale_bilinear_f32;
- case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
- return ctx->device->pipeline_upscale_bilinear_ac_f32;
- }
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
+ return ctx->device->pipeline_upscale_f32;
}
return nullptr;
case GGML_OP_SCALE:
return ctx->device->pipeline_pad_f32;
}
return nullptr;
- case GGML_OP_ROLL:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_roll_f32;
- }
- return nullptr;
case GGML_OP_REPEAT:
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
return ctx->device->pipeline_repeat_f32;
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
- case GGML_OP_ROLL:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_CPY:
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
- float sf0 = (float)dst->ne[0] / src0->ne[0];
- float sf1 = (float)dst->ne[1] / src0->ne[1];
- float sf2 = (float)dst->ne[2] / src0->ne[2];
- float sf3 = (float)dst->ne[3] / src0->ne[3];
-
- if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
- sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
- sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
- }
+ const float sf0 = (float)dst->ne[0] / src0->ne[0];
+ const float sf1 = (float)dst->ne[1] / src0->ne[1];
+ const float sf2 = (float)dst->ne[2] / src0->ne[2];
+ const float sf3 = (float)dst->ne[3] / src0->ne[3];
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0, 0,
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
sf0, sf1, sf2, sf3,
}
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
- p.param1 = ggml_get_op_params_f32(dst, 0);
+ float * op_params = (float *)dst->op_params;
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ op_params[0], 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
- p.param1 = ggml_get_op_params_f32(dst, 0);
- p.param2 = ggml_get_op_params_f32(dst, 1);
+ float * op_params = (float *)dst->op_params;
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ op_params[0], op_params[1],
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
-}
-
-static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const int32_t s0 = ggml_get_op_params_i32(dst, 0);
- const int32_t s1 = ggml_get_op_params_i32(dst, 1);
- const int32_t s2 = ggml_get_op_params_i32(dst, 2);
- const int32_t s3 = ggml_get_op_params_i32(dst, 3);
- const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
- const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
-
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
- memcpy(&p.param1, &s01_packed, sizeof(float));
- memcpy(&p.param2, &s23_packed, sizeof(float));
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
}
}
- vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
+ ne,
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
}
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
- case GGML_OP_ROLL:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_PAD:
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
- break;
- case GGML_OP_ROLL:
- ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
-
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
- case GGML_OP_ROLL:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ if (op->src[0]->ne[3] != 1) {
+ return false;
+ }
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
+ return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_SCALE:
case GGML_OP_PAD:
- case GGML_OP_ROLL:
case GGML_OP_DIAG_MASK_INF:
+ return true;
case GGML_OP_SOFT_MAX:
+ // TODO: support batching
+ if (op->src[0]->ne[3] != 1) {
+ return false;
+ }
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
layout (push_constant) uniform parameter
{
uint ne; uint a_offset; uint d_offset;
- uint ne00; uint ne01;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
-#define NEAREST 0
-#define BILINEAR 1
-#define ALIGN_CORNERS (1 << 8)
-
-layout (constant_id = 0) const uint scale_mode = 0;
-
-float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
- const uint i00 = uint(i10 / p.sf0);
- const uint i01 = uint(i11 / p.sf1);
- const uint i02 = uint(i12 / p.sf2);
- const uint i03 = uint(i13 / p.sf3);
-
- return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
-}
-
-float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
- const uint i02 = uint(i12 / p.sf2);
- const uint i03 = uint(i13 / p.sf3);
- const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
-
- const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
- const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
- const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
- const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
-
- return
- v00 * (1.0-d.x) * (1.0-d.y) +
- v01 * d.x * (1.0-d.y) +
- v10 * (1.0-d.x) * d.y +
- v11 * d.x * d.y;
-}
-
-float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
- const ivec2 ne0 = ivec2(p.ne00, p.ne01);
-
- const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
- const vec2 c0f = floor(c);
- const vec2 d = c - c0f;
- const ivec2 c0 = max(ivec2(c0f), 0);
- const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
-
- return fetch_bilinear(c0, c1, d, i12, i13);
-}
-
-float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
- const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
- const vec2 c0f = floor(c);
- const vec2 d = c - c0f;
- const ivec2 c0 = ivec2(c0f);
- const ivec2 c1 = c0 + 1;
-
- return fetch_bilinear(c0, c1, d, i12, i13);
-}
-
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
- float result;
- switch (scale_mode) {
- case NEAREST:
- result = fetch_nearest(i10, i11, i12, i13);
- break;
- case BILINEAR:
- result = interpolate_bilinear(i10, i11, i12, i13);
- break;
- case BILINEAR | ALIGN_CORNERS:
- result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
- break;
- }
+ const uint i00 = uint(i10 / p.sf0);
+ const uint i01 = uint(i11 / p.sf1);
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
- data_d[p.d_offset + idx] = D_TYPE(result);
+ data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
}
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
- string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
for (auto &c : compiles) {
c.wait();
}
return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
}
-const char * ggml_version(void) {
- return GGML_VERSION;
-}
-
-const char * ggml_commit(void) {
- return GGML_COMMIT;
-}
-
//
// timing
//
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
- GGML_ASSERT(ggml_is_matrix(mask));
+ GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+ GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
+ GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
}
if (max_bias > 0.0f) {
GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
+ GGML_ASSERT(q->ne[3] == k->ne[3]);
+ GGML_ASSERT(q->ne[3] == v->ne[3]);
+
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
- GGML_ASSERT(mask->ne[2] == 1);
- GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+
+ GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
}
if (max_bias > 0.0f) {
const std::array<int64_t, 4> ne;
const bool mask;
const ggml_type m_prec;
+ const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
const float scale;
const float max_bias;
std::string vars() override {
- return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
+ return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
}
// the 1024 test with bias occasionally fails:
std::array<int64_t, 4> ne = {10, 5, 4, 3},
bool mask = false,
ggml_type m_prec = GGML_TYPE_F32,
+ std::array<int64_t, 2> nr23 = {1, 1},
float scale = 1.0f,
float max_bias = 0.0f)
- : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
+ : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * mask = nullptr;
if (this->mask) {
- mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
+ mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(mask, "mask");
}
}
};
-// GGML_OP_ROLL
-struct test_roll : public test_case {
- const int shift0;
- const int shift1;
- const int shift3;
- const int shift4;
-
- std::string vars() override {
- return VARS_TO_STR4(shift0, shift1, shift3, shift4);
- }
-
- test_roll(int shift0 = 3, int shift1 = -2, int shift3 = 1, int shift4 = -1)
- : shift0(shift0), shift1(shift1), shift3(shift3), shift4(shift4) {}
-
- ggml_tensor * build_graph(ggml_context * ctx) override {
- int64_t ne[4] = {10, 5, 4, 3};
- ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
- ggml_set_name(a, "a");
-
- ggml_tensor * out = ggml_roll(ctx, a, shift0, shift1, shift3, shift4);
- ggml_set_name(out, "out");
-
- return out;
- }
-};
-
// GGML_OP_ARANGE
struct test_arange : public test_case {
const ggml_type type;
const int64_t hsk; // K head size
const int64_t hsv; // V head size
const int64_t nh; // num heads
- const int64_t nr; // repeat in Q, tests for grouped-query attention
+ const std::array<int64_t, 2> nr23; // repeat in dim 2 and 3, tests for grouped-query attention
const int64_t kv; // kv size
const int64_t nb; // batch size
std::array<int32_t, 4> permute;
std::string vars() override {
- return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+ return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
}
double max_nmse_err() override {
GGML_UNUSED(t);
// Just counting matmul costs:
// Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
- return 2 * nh*nr * nb * (hsk + hsv) * kv;
+ return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1];
}
- test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+ test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
- : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+ : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
return t;
};
- ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
+ ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
ggml_set_name(q, "q");
- ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
+ ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
ggml_set_name(k, "k");
- ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
+ ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
if (mask) {
- m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+ m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
ggml_set_name(m, "m");
}
for (int64_t ne1 : {16, 1024}) {
if (mask) {
for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+
+ if (ne0 <= 32 && ne1 <= 32) {
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+ }
}
} else {
/* The precision of mask here doesn't matter as boolean mask is false */
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
}
}
}
}
}
}
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad_reflect_1d());
- test_cases.emplace_back(new test_roll());
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
for (float logit_softcap : {0.0f, 10.0f}) {
if (hsk != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 4, }) {
- for (int nr : { 1, 4, 16 }) {
- if (nr == 16 && hsk != 128) continue;
- for (int kv : { 512, 1024, }) {
- if (nr != 1 && kv != 512) continue;
- for (int nb : { 1, 3, 32, 35, }) {
- for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
- if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
- // run fewer test cases permuted
- if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ for (int nr3 : { 1, 3, }) {
+ if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+ for (int nr2 : { 1, 4, 16 }) {
+ if (nr2 == 16 && hsk != 128) continue;
+ for (int kv : { 512, 1024, }) {
+ if (nr2 != 1 && kv != 512) continue;
+ for (int nb : { 1, 3, 32, 35, }) {
+ for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+ if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+ // run fewer test cases permuted
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ test_cases.emplace_back(new test_flash_attn_ext(
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ }
}
}
}
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+ test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
}
}
}