GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
+ GGML_OP_TOP_K,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,
GGML_OP_FILL,
struct ggml_tensor * a,
enum ggml_sort_order order);
- GGML_API struct ggml_tensor * ggml_arange(
+ // similar to ggml_top_k but implemented as `argsort` + `view`
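+ // note: unlike ggml_top_k, the returned indices are ordered (by descending value)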
+ GGML_API struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
- float start,
- float stop,
- float step);
+ struct ggml_tensor * a,
+ int k);
// top k elements per row
+ // note: the resulting top k indices are in no particular order
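+ //
+ // usage sketch (illustrative; assumes an existing `ctx` and an F32 tensor `logits` with ne[0] >= 10):
+ //
+ //   struct ggml_tensor * idx = ggml_top_k(ctx, logits, 10);
+ //   // -> GGML_TYPE_I32 tensor of shape [10, logits->ne[1], logits->ne[2], logits->ne[3]]
+ //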
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
+ GGML_API struct ggml_tensor * ggml_arange(
+ struct ggml_context * ctx,
+ float start,
+ float stop,
+ float step);
+
#define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3 ]
{
ggml_compute_forward_argsort(params, tensor);
} break;
+ case GGML_OP_TOP_K:
+ {
+ ggml_compute_forward_top_k(params, tensor);
+ } break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
+ case GGML_OP_TOP_K:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break;
+ case GGML_OP_TOP_K:
+ {
+ cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+ } break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK
// ggml_compute_forward_argsort
template<enum ggml_sort_order order>
-struct argsort_cmp {
+struct cmp_argsort {
const float * data;
bool operator()(int32_t a, int32_t b) const {
if constexpr (order == GGML_SORT_ORDER_ASC) {
switch (order) {
case GGML_SORT_ORDER_ASC:
- std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
break;
case GGML_SORT_ORDER_DESC:
- std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
break;
default:
    GGML_ABORT("fatal error");
}
}
+// ggml_compute_forward_top_k
+
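+// comparator for std::partial_sort: orders indices by descending value, so the
+// first k entries end up referencing the k largest elements of the row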
+struct cmp_top_k {
+ const float * data;
+ bool operator()(int32_t a, int32_t b) const {
+ return data[a] > data[b];
+ }
+};
+
+static void ggml_compute_forward_top_k_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ const int top_k = ne0;
+
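+ // per-thread scratch buffer of ne00 indices in wdata; threads are spaced one extra
+ // cache line apart (CACHE_LINE_SIZE_F32) to avoid false sharing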
+ int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne00; j++) {
+ tmp[j] = j;
+ }
+
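+ // move the indices of the top_k largest values to the front of tmp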
+ std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+ std::copy(tmp, tmp + top_k, dst_data);
+
+ // deliberately perturb the result to emphasize that the order of the top_k indices is not guaranteed
+ if (top_k > 1) {
+ std::swap(dst_data[0], dst_data[1]);
+ }
+ }
+}
+
+void ggml_compute_forward_top_k(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_top_k_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
return res;
}
+// note: reuse the argsort kernel for top_k
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_TOP_K);
+
+ char base[256];
+ char name[256];
+
+ // note: the top_k kernel always uses descending sort order
+ ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+ const char * order_str = "undefined";
+ switch (order) {
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+ default: GGML_ABORT("fatal error");
+ };
+
+ snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+ if (res) {
+ return res;
+ }
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+ return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_TOP_K);
+
+ char base[256];
+ char name[256];
+
+ ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+ const char * order_str = "undefined";
+ switch (order) {
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+ default: GGML_ABORT("fatal error");
+ };
+
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+ if (res) {
+ return res;
+ }
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+ return res;
+}
+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
+ case GGML_OP_TOP_K:
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:
} ggml_metal_kargs_leaky_relu;
typedef struct {
- int64_t ne00;
- int64_t ne01;
- int64_t ne02;
- int64_t ne03;
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ int32_t top_k;
} ggml_metal_kargs_argsort;
typedef struct {
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ int32_t top_k;
int32_t len;
} ggml_metal_kargs_argsort_merge;
{
n_fuse = ggml_metal_op_argsort(ctx, idx);
} break;
+ case GGML_OP_TOP_K:
+ {
+ n_fuse = ggml_metal_op_top_k(ctx, idx);
+ } break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
}
ggml_metal_kargs_argsort args = {
- /*.ne00 =*/ ne00,
- /*.ne01 =*/ ne01,
- /*.ne02 =*/ ne02,
- /*.ne03 =*/ ne03,
- /*.nb00 =*/ nb00,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb03 =*/ nb03,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ nth,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_op_concurrency_reset(ctx);
ggml_metal_kargs_argsort_merge args_merge = {
- .ne00 = ne00,
- .ne01 = ne01,
- .ne02 = ne02,
- .ne03 = ne03,
- .nb00 = nb00,
- .nb01 = nb01,
- .nb02 = nb02,
- .nb03 = nb03,
- .len = len,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ ne00,
+ /*.len =*/ len,
};
// merges per row
return 1;
}
+int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
+
+ // bitonic sort requires the number of elements to be a power of 2
+ int nth = 1;
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ // blocks per row
+ const int npr = (ne00 + nth - 1)/nth;
+
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
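+ // the dst allocation is oversized (see the GGML_OP_TOP_K extra-size case): the region after the
+ // first ggml_nelements(src0) int32 values serves as scratch for the merge passes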
+ ggml_metal_buffer_id bid_tmp = bid_dst;
+ bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
+
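+ // the sort and merge passes alternate between dst and the scratch buffer; pre-swap the two
+ // buffer ids when the number of merge passes (ceil(log2(npr))) is odd so that the final pass
+ // writes into the real dst buffer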
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+ std::swap(bid_dst, bid_tmp);
+ }
+
+ const int top_k = ne0;
+
+ ggml_metal_kargs_argsort args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
+ };
+
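+ // after the first pass each block contributes only its top_k candidates, so the merge passes
+ // operate on a compacted row of (npr - 1)*top_k elements plus the last block's contribution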
+ if (npr > 1) {
+ args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
+ }
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+ ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
+
+ int len = args.top_k;
+
+ while (len < args.ne0) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ // merges per row
+ const int nm = (args.ne0 + 2*len - 1) / (2*len);
+
+ const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
+
+ ggml_metal_kargs_argsort_merge args_merge = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ args.ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
+ /*.len =*/ len,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+ std::swap(bid_dst, bid_tmp);
+
+ len <<= 1;
+ }
+
+ return 1;
+}
+
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
{
res *= 2;
} break;
+ case GGML_OP_TOP_K:
+ {
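+ // two int32 index arrays the size of src0: one for the (intermediate) output,
+ // one as scratch for the merge passes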
+ res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
+ } break;
default:
break;
}
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
- const int i00 = (tgpig[0]/args.ne01)*ntg.x;
- const int i01 = tgpig[0]%args.ne01;
- const int i02 = tgpig[1];
- const int i03 = tgpig[2];
+ const int ib = tgpig[0] / args.ne01;
+ const int i00 = ib*ntg.x;
+ const int i01 = tgpig[0] % args.ne01;
+ const int i02 = tgpig[1];
+ const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
}
}
+ const int64_t i0 = ib*args.top_k;
+
// copy the result to dst without the padding
- if (i00 + col < args.ne00) {
- dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+ if (i0 + col < args.ne0 && col < args.top_k) {
+ dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
const int start = im * (2 * args.len);
- const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
- const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+ const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
+ const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
- + i01*args.ne00
- + i02*args.ne00*args.ne01
- + i03*args.ne00*args.ne01*args.ne02;
+ + i01*args.ne0
+ + i02*args.ne0*args.ne01
+ + i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
- + i01*args.ne00
- + i02*args.ne00*args.ne01
- + i03*args.ne00*args.ne01*args.ne02;
+ + i01*args.top_k
+ + i02*args.top_k*args.ne01
+ + i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
- const int k1 = min(k0 + chunk, total);
+ const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
+
+ if (k0 >= args.top_k) {
+ return;
+ }
if (k0 >= total) {
return;
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
+ "TOP_K",
"LEAKY_RELU",
"TRI",
"FILL",
"GLU",
};
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
+ "top_k(x)",
"leaky_relu(x)",
"tri(x)",
"fill(x, c)",
"glu(x)",
};
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
return result;
}
-// ggml_arange
-
-struct ggml_tensor * ggml_arange(
- struct ggml_context * ctx,
- float start,
- float stop,
- float step) {
- GGML_ASSERT(stop > start);
-
- const int64_t steps = (int64_t) ceilf((stop - start) / step);
-
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
-
- ggml_set_op_params_f32(result, 0, start);
- ggml_set_op_params_f32(result, 1, stop);
- ggml_set_op_params_f32(result, 2, step);
-
- result->op = GGML_OP_ARANGE;
-
- return result;
-}
-
// ggml_timestep_embedding
struct ggml_tensor * ggml_timestep_embedding(
struct ggml_tensor * a,
enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX);
+
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order);
return result;
}
-// ggml_top_k
+// ggml_argsort_top_k
-struct ggml_tensor * ggml_top_k(
+struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k) {
return result;
}
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int k) {
+ GGML_ASSERT(a->ne[0] >= k);
+
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
+
+ result->op = GGML_OP_TOP_K;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_arange
+
+struct ggml_tensor * ggml_arange(
+ struct ggml_context * ctx,
+ float start,
+ float stop,
+ float step) {
+ GGML_ASSERT(stop > start);
+
+ const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+ ggml_set_op_params_f32(result, 0, start);
+ ggml_set_op_params_f32(result, 1, stop);
+ ggml_set_op_params_f32(result, 2, step);
+
+ result->op = GGML_OP_ARANGE;
+
+ return result;
+}
+
// ggml_flash_attn_ext
struct ggml_tensor * ggml_flash_attn_ext(