case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
- case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
+ case GGML_OP_ARGSORT:
+ // TODO: Support arbitrary column width
+ return op->src[0]->ne[0] <= 1024;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_ARGSORT:
+ // TODO: Support arbitrary column width
+ return op->src[0]->ne[0] <= 1024;
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
+ test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {