case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
- return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
+ return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_REPEAT:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
- GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
}
ggml_metal_kargs_bin args = {
- /*.ne00 =*/ ne00,
- /*.ne01 =*/ ne01,
- /*.ne02 =*/ ne02,
- /*.ne03 =*/ ne03,
+ /*.ne00 =*/ ne10,
+ /*.ne01 =*/ ne11,
+ /*.ne02 =*/ ne12,
+ /*.ne03 =*/ ne13,
/*.nb00 =*/ nb00,
/*.nb01 =*/ pnb1,
/*.nb02 =*/ pnb2,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
- /*.ne0 =*/ ne0,
- /*.ne1 =*/ ne1,
- /*.ne2 =*/ ne2,
- /*.ne3 =*/ ne3,
+ /*.ne0 =*/ ne10,
+ /*.ne1 =*/ ne11,
+ /*.ne2 =*/ ne12,
+ /*.ne3 =*/ ne13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
- const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ int nth = 1;
+
+ while (2*nth < args.ne0 && nth < nth_max) {
+ nth *= 2;
+ }
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);