]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sam : passing parameters and simple prompt (#598)
authorJiří Podivín <redacted>
Thu, 2 Nov 2023 19:28:11 +0000 (20:28 +0100)
committerGitHub <redacted>
Thu, 2 Nov 2023 19:28:11 +0000 (21:28 +0200)
- most of the model hyperparameters can now be set on CLI
- user can define their own mask prefix
- user can define their own point prompt, although just one

Signed-off-by: Jiri Podivin <redacted>
examples/sam/main.cpp

index 7c4130fddace7adb53a16a0effa529529bc3298f..38d5e2734a88286cd13155c48503812567d52261 100644 (file)
@@ -296,6 +296,22 @@ struct sam_image_f32 {
     std::vector<float> data;
 };
 
+struct sam_params {
+    int32_t seed      = -1; // RNG seed
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+
+    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
+    std::string fname_inp = "img.jpg";
+    std::string fname_out = "img.out";
+    float   mask_threshold            = 0.f;
+    float   iou_threshold             = 0.88f;
+    float   stability_score_threshold = 0.95f;
+    float   stability_score_offset    = 1.0f;
+    float   eps                       = 1e-6f;
+    float   eps_decoder_transformer   = 1e-5f;
+    sam_point pt = { 414.375f, 162.796875f, };
+};
+
 void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
     printf("%s\n", title);
     float * data = (float *)t->data;
@@ -469,12 +485,12 @@ bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) {
 }
 
 // load the model's weights from a file
-bool sam_model_load(const std::string & fname, sam_model & model) {
-    fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+bool sam_model_load(const sam_params & params, sam_model & model) {
+    fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str());
 
-    auto fin = std::ifstream(fname, std::ios::binary);
+    auto fin = std::ifstream(params.model, std::ios::binary);
     if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str());
         return false;
     }
 
@@ -483,13 +499,21 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
         uint32_t magic;
         fin.read((char *) &magic, sizeof(magic));
         if (magic != 0x67676d6c) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str());
             return false;
         }
     }
 
     // load hparams
     {
+        // Override defaults with user choices
+        model.hparams.mask_threshold            = params.mask_threshold;
+        model.hparams.iou_threshold             = params.iou_threshold;
+        model.hparams.stability_score_threshold = params.stability_score_threshold;
+        model.hparams.stability_score_offset    = params.stability_score_offset;
+        model.hparams.eps                       = params.eps;
+        model.hparams.eps_decoder_transformer   = params.eps_decoder_transformer;
+
         auto & hparams = model.hparams;
 
         fin.read((char *) &hparams.n_enc_state,     sizeof(hparams.n_enc_state));
@@ -510,6 +534,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
         printf("%s: qntvr            = %d\n", __func__, qntvr);
 
         hparams.ftype %= GGML_QNT_VERSION_FACTOR;
+
     }
 
     // for the big tensors, we have the option to store the data in 16-bit floats or quantized
@@ -517,7 +542,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
     ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
     if (wtype == GGML_TYPE_COUNT) {
         fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
-                __func__, fname.c_str(), model.hparams.ftype);
+                __func__, params.model.c_str(), model.hparams.ftype);
         return false;
     }
 
@@ -1791,7 +1816,7 @@ bool sam_decode_mask(
     return true;
 }
 
-bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
+bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname) {
     if (state.low_res_masks->ne[2] == 0) return true;
     if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
         printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
@@ -1938,7 +1963,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state
         printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n",
                 i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy);
 
-        std::string filename = "mask_out_" + std::to_string(i) + ".png";
+        std::string filename = fname + std::to_string(i) + ".png";
         if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) {
             printf("%s: failed to write mask %s\n", __func__, filename.c_str());
             return false;
@@ -1967,7 +1992,7 @@ struct ggml_cgraph  * sam_build_fast_graph(
 
     prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point);
     if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {
-        fprintf(stderr, "%s: failed to encode prompt\n", __func__);
+        fprintf(stderr, "%s: failed to encode prompt (%f, %f)\n", __func__, point.x, point.y);
         return {};
     }
 
@@ -1986,14 +2011,6 @@ struct ggml_cgraph  * sam_build_fast_graph(
 
     return gf;
 }
-struct sam_params {
-    int32_t seed      = -1; // RNG seed
-    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-
-    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
-    std::string fname_inp = "img.jpg";
-    std::string fname_out = "img.out";
-};
 
 void sam_print_usage(int argc, char ** argv, const sam_params & params) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
@@ -2007,7 +2024,23 @@ void sam_print_usage(int argc, char ** argv, const sam_params & params) {
     fprintf(stderr, "  -i FNAME, --inp FNAME\n");
     fprintf(stderr, "                        input file (default: %s)\n", params.fname_inp.c_str());
     fprintf(stderr, "  -o FNAME, --out FNAME\n");
-    fprintf(stderr, "                        output file (default: %s)\n", params.fname_out.c_str());
+    fprintf(stderr, "                        mask file name prefix (default: %s)\n", params.fname_out.c_str());
+    fprintf(stderr, "SAM hyperparameters:\n");
+    fprintf(stderr, "  -mt FLOAT, --mask-threshold\n");
+    fprintf(stderr, "                        mask threshold (default: %f)\n", params.mask_threshold);
+    fprintf(stderr, "  -it FLOAT, --iou-threshold\n");
+    fprintf(stderr, "                        iou threshold (default: %f)\n", params.iou_threshold);
+    fprintf(stderr, "  -st FLOAT, --score-threshold\n");
+    fprintf(stderr, "                        score threshold (default: %f)\n", params.stability_score_threshold);
+    fprintf(stderr, "  -so FLOAT, --score-offset\n");
+    fprintf(stderr, "                        score offset (default: %f)\n", params.stability_score_offset);
+    fprintf(stderr, "  -e FLOAT, --epsilon\n");
+    fprintf(stderr, "                        epsilon (default: %f)\n", params.eps);
+    fprintf(stderr, "  -ed FLOAT, --epsilon-decoder-transformer\n");
+    fprintf(stderr, "                        epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer);
+    fprintf(stderr, "SAM prompt:\n");
+    fprintf(stderr, "  -p TUPLE, --point-prompt\n");
+    fprintf(stderr, "                        point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.pt.x, params.pt.y);
     fprintf(stderr, "\n");
 }
 
@@ -2025,6 +2058,34 @@ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
             params.fname_inp = argv[++i];
         } else if (arg == "-o" || arg == "--out") {
             params.fname_out = argv[++i];
+        } else if (arg == "-mt" || arg == "--mask-threshold") {
+            params.mask_threshold = std::stof(argv[++i]);
+        } else if (arg == "-it" || arg == "--iou-threshold") {
+            params.iou_threshold = std::stof(argv[++i]);
+        } else if (arg == "-st" || arg == "--score-threshold") {
+            params.stability_score_threshold = std::stof(argv[++i]);
+        } else if (arg == "-so" || arg == "--score-offset") {
+            params.stability_score_offset = std::stof(argv[++i]);
+        } else if (arg == "-e" || arg == "--epsilon") {
+            params.eps = std::stof(argv[++i]);
+        } else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") {
+            params.eps_decoder_transformer = std::stof(argv[++i]);
+        } else if (arg == "-p" || arg == "--point-prompt") {
+            // TODO multiple points per model invocation
+            char* point = argv[++i];
+
+            char* coord = strtok(point, ",");
+            if (!coord){
+                fprintf(stderr, "Error while parsing prompt!\n");
+                exit(1);
+            }
+            params.pt.x = std::stof(coord);
+            coord = strtok(NULL, ",");
+            if (!coord){
+                fprintf(stderr, "Error while parsing prompt!\n");
+                exit(1);
+            }
+            params.pt.y = std::stof(coord);
         } else if (arg == "-h" || arg == "--help") {
             sam_print_usage(argc, argv, params);
             exit(0);
@@ -2078,7 +2139,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!sam_model_load(params.model, model)) {
+        if (!sam_model_load(params, model)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -2147,10 +2208,11 @@ int main(int argc, char ** argv) {
         state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
         state.allocr = ggml_allocr_new_measure(tensor_alignment);
 
-        // TODO: user input
-        const sam_point pt = { 414.375f, 162.796875f, };
+        // TODO: more varied prompts
+        fprintf(stderr, "prompt: (%f, %f)\n", params.pt.x, params.pt.y);
+
         // measure memory requirements for the graph
-        struct ggml_cgraph  * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
+        struct ggml_cgraph  * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
         if (!gf_measure) {
             fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
             return 1;
@@ -2166,7 +2228,7 @@ int main(int argc, char ** argv) {
         // compute the graph with the measured exact memory requirements from above
         ggml_allocr_reset(state.allocr);
 
-        struct ggml_cgraph  * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
+        struct ggml_cgraph  * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
         if (!gf) {
             fprintf(stderr, "%s: failed to build fast graph\n", __func__);
             return 1;
@@ -2182,7 +2244,7 @@ int main(int argc, char ** argv) {
         state.allocr = NULL;
     }
 
-    if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
+    if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out)) {
         fprintf(stderr, "%s: failed to write masks\n", __func__);
         return 1;
     }