* check cuda argsort limits and add test
* add metal check
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
- case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
+ case GGML_OP_ARGSORT:
+ // TODO: Support arbitrary column width
+ return op->src[0]->ne[0] <= 1024;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_ARGSORT:
+ // TODO: Support arbitrary column width
+ return op->src[0]->ne[0] <= 1024;
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT: