// enter here only when capturing in order to wait for all computation to finish
// otherwise, we leave the graph to compute asynchronously
- if (!use_capture && ctx->capture_started) {
+ if (use_capture && ctx->capture_started) {
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
[ctx->capture_scope endScope];
[[MTLCaptureManager sharedCaptureManager] stopCapture];
+
+ ctx->capture_started = false;
}
}
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
+ const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
- snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
+ snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
+ ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
if (pipeline.cnt) {
- const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
+constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
+#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
- const uint i0 = tgpig.x;
- const uint i1 = i0%args.ne10;
+ const uint i0 = tgpig.y*args.ne00 + tgpig.x;
+ const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
+ const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
+ const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
#undef FC_OP
#undef FC_F
#undef FC_RB
+#undef FC_CB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;