]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mnist : remove redundant stuff + rename ctx0
authorGeorgi Gerganov <redacted>
Mon, 29 May 2023 18:14:52 +0000 (21:14 +0300)
committerGeorgi Gerganov <redacted>
Mon, 29 May 2023 18:15:17 +0000 (21:15 +0300)
examples/mnist/main-cpu.cpp
examples/mnist/main-mtl.cpp

index 42a29407bb6c1d66153e20202989ab980aa5bbfb..bcb402da3fbe7a7b86a029d1e8cf6b63c938a8cf 100644 (file)
@@ -39,7 +39,7 @@ int mnist_eval(
     struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
     gfi.n_threads = n_threads;
 
-    // allocate eval context
+    // allocate work context
     // needed during ggml_graph_compute() to allocate a work tensor
     static size_t buf_size = gfi.work_size; // TODO
     static void * buf = malloc(buf_size);
@@ -50,18 +50,18 @@ int mnist_eval(
         .no_alloc   = false,
     };
 
-    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_context * ctx_work = ggml_init(params);
 
     struct ggml_tensor * input = ggml_graph_get_tensor(&gfi, "input");
     memcpy(input->data, digit.data(), ggml_nbytes(input));
 
-    ggml_graph_compute(ctx0, &gfi);
+    ggml_graph_compute(ctx_work, &gfi);
 
     const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(&gfi, "probs"));
 
     const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
 
-    ggml_free(ctx0);
+    ggml_free(ctx_work);
     ggml_free(ctx_data);
     ggml_free(ctx_eval);
 
index fafe8e6102aab8bbe99a901a700e6a02c9be8185..1f9475d85aa509aa30b366cf99aef713a6a58070 100644 (file)
 // evaluate the MNIST compute graph
 //
 //   - fname_cgraph: path to the compute graph
-//   - n_threads:    number of threads to use
 //   - digit:        784 pixel values
 //
 // returns 0 - 9 prediction
 int mnist_eval(
         const char * fname_cgraph,
-        const int n_threads,
         std::vector<float> digit
         ) {
     // load the compute graph
@@ -38,10 +36,9 @@ int mnist_eval(
     struct ggml_context * ctx_eval = NULL;
 
     struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
-    gf.n_threads = n_threads;
+    gf.n_threads = 1;
 
-    // allocate eval context
-    // needed during ggml_graph_compute() to allocate a work tensor
+    // allocate work context
     static size_t buf_size = gf.work_size; // TODO
     static void * buf = malloc(buf_size);
 
@@ -121,7 +118,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");
     }
 
-    const int prediction = mnist_eval(argv[1], 1, digit);
+    const int prediction = mnist_eval(argv[1], digit);
 
     fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);