std::vector<float> data;
};
+struct sam_params {
+ int32_t seed = -1; // RNG seed
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+
+ std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
+ std::string fname_inp = "img.jpg";
+ std::string fname_out = "img.out";
+ float mask_threshold = 0.f;
+ float iou_threshold = 0.88f;
+ float stability_score_threshold = 0.95f;
+ float stability_score_offset = 1.0f;
+ float eps = 1e-6f;
+ float eps_decoder_transformer = 1e-5f;
+ sam_point pt = { 414.375f, 162.796875f, };
+};
+
void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
printf("%s\n", title);
float * data = (float *)t->data;
}
// load the model's weights from a file
-bool sam_model_load(const std::string & fname, sam_model & model) {
- fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+bool sam_model_load(const sam_params & params, sam_model & model) {
+ fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str());
- auto fin = std::ifstream(fname, std::ios::binary);
+ auto fin = std::ifstream(params.model, std::ios::binary);
if (!fin) {
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str());
return false;
}
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
- fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str());
return false;
}
}
// load hparams
{
+ // Override defaults with user choices
+ model.hparams.mask_threshold = params.mask_threshold;
+ model.hparams.iou_threshold = params.iou_threshold;
+ model.hparams.stability_score_threshold = params.stability_score_threshold;
+ model.hparams.stability_score_offset = params.stability_score_offset;
+ model.hparams.eps = params.eps;
+ model.hparams.eps_decoder_transformer = params.eps_decoder_transformer;
+
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state));
printf("%s: qntvr = %d\n", __func__, qntvr);
hparams.ftype %= GGML_QNT_VERSION_FACTOR;
+
}
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
- __func__, fname.c_str(), model.hparams.ftype);
+ __func__, params.model.c_str(), model.hparams.ftype);
return false;
}
return true;
}
-bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
+bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname) {
if (state.low_res_masks->ne[2] == 0) return true;
if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n",
i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy);
- std::string filename = "mask_out_" + std::to_string(i) + ".png";
+ std::string filename = fname + std::to_string(i) + ".png";
if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) {
printf("%s: failed to write mask %s\n", __func__, filename.c_str());
return false;
prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point);
if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {
- fprintf(stderr, "%s: failed to encode prompt\n", __func__);
+ fprintf(stderr, "%s: failed to encode prompt (%f, %f)\n", __func__, point.x, point.y);
return {};
}
return gf;
}
-struct sam_params {
- int32_t seed = -1; // RNG seed
- int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-
- std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
- std::string fname_inp = "img.jpg";
- std::string fname_out = "img.out";
-};
void sam_print_usage(int argc, char ** argv, const sam_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, " -i FNAME, --inp FNAME\n");
fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
fprintf(stderr, " -o FNAME, --out FNAME\n");
- fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
+ fprintf(stderr, " mask file name prefix (default: %s)\n", params.fname_out.c_str());
+ fprintf(stderr, "SAM hyperparameters:\n");
+ fprintf(stderr, " -mt FLOAT, --mask-threshold\n");
+ fprintf(stderr, " mask threshold (default: %f)\n", params.mask_threshold);
+ fprintf(stderr, " -it FLOAT, --iou-threshold\n");
+ fprintf(stderr, " iou threshold (default: %f)\n", params.iou_threshold);
+ fprintf(stderr, " -st FLOAT, --score-threshold\n");
+ fprintf(stderr, " score threshold (default: %f)\n", params.stability_score_threshold);
+ fprintf(stderr, " -so FLOAT, --score-offset\n");
+ fprintf(stderr, " score offset (default: %f)\n", params.stability_score_offset);
+ fprintf(stderr, " -e FLOAT, --epsilon\n");
+ fprintf(stderr, " epsilon (default: %f)\n", params.eps);
+ fprintf(stderr, " -ed FLOAT, --epsilon-decoder-transformer\n");
+ fprintf(stderr, " epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer);
+ fprintf(stderr, "SAM prompt:\n");
+ fprintf(stderr, " -p TUPLE, --point-prompt\n");
+ fprintf(stderr, " point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.pt.x, params.pt.y);
fprintf(stderr, "\n");
}
params.fname_inp = argv[++i];
} else if (arg == "-o" || arg == "--out") {
params.fname_out = argv[++i];
+ } else if (arg == "-mt" || arg == "--mask-threshold") {
+ params.mask_threshold = std::stof(argv[++i]);
+ } else if (arg == "-it" || arg == "--iou-threshold") {
+ params.iou_threshold = std::stof(argv[++i]);
+ } else if (arg == "-st" || arg == "--score-threshold") {
+ params.stability_score_threshold = std::stof(argv[++i]);
+ } else if (arg == "-so" || arg == "--score-offset") {
+ params.stability_score_offset = std::stof(argv[++i]);
+ } else if (arg == "-e" || arg == "--epsilon") {
+ params.eps = std::stof(argv[++i]);
+ } else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") {
+ params.eps_decoder_transformer = std::stof(argv[++i]);
+ } else if (arg == "-p" || arg == "--point-prompt") {
+ // TODO multiple points per model invocation
+ char* point = argv[++i];
+
+ char* coord = strtok(point, ",");
+ if (!coord){
+ fprintf(stderr, "Error while parsing prompt!\n");
+ exit(1);
+ }
+ params.pt.x = std::stof(coord);
+ coord = strtok(NULL, ",");
+ if (!coord){
+ fprintf(stderr, "Error while parsing prompt!\n");
+ exit(1);
+ }
+ params.pt.y = std::stof(coord);
} else if (arg == "-h" || arg == "--help") {
sam_print_usage(argc, argv, params);
exit(0);
{
const int64_t t_start_us = ggml_time_us();
- if (!sam_model_load(params.model, model)) {
+ if (!sam_model_load(params, model)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
state.allocr = ggml_allocr_new_measure(tensor_alignment);
- // TODO: user input
- const sam_point pt = { 414.375f, 162.796875f, };
+ // TODO: more varied prompts
+ fprintf(stderr, "prompt: (%f, %f)\n", params.pt.x, params.pt.y);
+
// measure memory requirements for the graph
- struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
+ struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
if (!gf_measure) {
fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
return 1;
// compute the graph with the measured exact memory requirements from above
ggml_allocr_reset(state.allocr);
- struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
+ struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
if (!gf) {
fprintf(stderr, "%s: failed to build fast graph\n", __func__);
return 1;
state.allocr = NULL;
}
- if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
+ if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out)) {
fprintf(stderr, "%s: failed to write masks\n", __func__);
return 1;
}