]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Update README.md + minor stuff
authorGeorgi Gerganov <redacted>
Mon, 19 Sep 2022 21:09:34 +0000 (00:09 +0300)
committerGeorgi Gerganov <redacted>
Mon, 19 Sep 2022 21:09:34 +0000 (00:09 +0300)
- Changed default threads to 4
- Added GGML_PERF for enabling runtime performance timings

README.md
examples/gpt-2/README.md
examples/gpt-j/README.md
examples/utils.h
src/CMakeLists.txt
src/ggml.c

index 550425b3e11e9986268f9c49644efd8ae61074fa..9678e3947fd4ed69e3787b6d540c80824ea8449e 100644 (file)
--- a/README.md
+++ b/README.md
@@ -1,24 +1,23 @@
 # ggml
 
-Tensor library in C for machine learning
+Tensor library for machine learning
 
 ## Features
 
-- Automatic differentiation (WIP)
+- Written in C
 - 16-bit float support
+- Automatic differentiation (WIP in progress)
 - ADAM and L-BFGS optimizers
-- Optimized for Arm64 architectures (i.e. MacBook M1) via NEON intrinsics
+- Optimized for Arm64 architectures (M1) via NEON intrinsics
 - On x86 architectures utilzes AVX intrinsics
 - No third-party dependencies
 - Zero memory allocations during runtime
 
-## Local GPT inference
+## Example - GPT inference
 
-Using ggml you can run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference locally on your computer without any additional software or hardware. You don't even need to install python or any other third-party library.
+With ggml you can efficiently run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference on the CPU.
 
-The example programs are implemented in C++. They run entirely on the CPU.
-
-Here is how to use them:
+Here is how to run the example programs:
 
 ```bash
 # Build ggml + examples
@@ -37,7 +36,7 @@ make -j4 gpt-2 gpt-j
 ./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p "This is an example"
 ```
 
-This is the inference speed for the different models on my MacBook M1 Pro:
+The inference speeds that I get for the different models on my 32GB MacBook M1 Pro are as follows:
 
 | Model | Size  | Time / Token |
 | ---   | ---   | ---    |
index 3543bb29504ef623aae144edc1aefb874d7c87fa..60fea55dc759248172ff2e5d2b97cb239499edfe 100644 (file)
@@ -1,7 +1,6 @@
 # gpt-2
 
 This is a C++ example running GPT-2 inference using the [ggml](https://github.com/ggerganov/ggml) library.
-The enitre code of the example is in [main.cpp](main.cpp).
 
 The program runs on the CPU - no video card is required.
 
@@ -73,11 +72,11 @@ main:    total time =   629.84 ms
 
 ## Downloading and converting the original models
 
-You can download the original model files using the [download-model.sh](download-model.sh) Bash script.
-The model is in Tensorflow format, so before using it with ggml, we need to convert it to appropriate format.
-This is done via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
+You can download the original model files using the [download-model.sh](download-model.sh) Bash script. The models are
+in Tensorflow format, so in order to use them with ggml, you need to convert them to appropriate format. This is done
+via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
 
-Here is the entire process for the GPT-2 117M model:
+Here is the entire process for the GPT-2 117M model (download from official site + conversion):
 
 ```
 cd ggml/build
@@ -99,14 +98,13 @@ Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.
 
 ```
 
-This conversion requires that you have python and Tensorflow installed on your computer.
-Still, if you want to avoid this, you can download the already converted ggml models as
-described below.
+This conversion requires that you have python and Tensorflow installed on your computer. Still, if you want to avoid
+this, you can download the already converted ggml models as described below.
 
 ## Downloading the ggml model directly
 
-For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples.
-This way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
+For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples. This
+way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
 
 Here is how to get the 117M ggml model:
 
@@ -123,4 +121,4 @@ You can now use it like this:
 
 ```
 
-At some point, I might stop hosting these models. So in that case, simply revert to the manual process above.
+At some point, I might decide to stop hosting these models. So in that case, simply revert to the manual process above.
index c5e0007cf0f4d86ae148252cbf84842a547b9e36..68c41361a805b9d10d6ae4ce6fd27a5e70ffd292 100644 (file)
@@ -4,25 +4,23 @@ Local GPT-J inference on your computer using C/C++
 
 No video card required. You just need to have 16 GB of RAM.
 
-For example, you can run this on a 16 GB MacBook M1.
-
 ## Motivation
 
-The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that
-allows you to generate coherent, human-like text given a certain context (prompt).
+The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that allows you to
+generate coherent, human-like text given a certain context (prompt).
 
-The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation
-of the weights and is still 12 GB big. This means that in order to run inference on your computer, you
-would need to have a video card with at least 12 GB of video RAM. Alternatively, you can try to run the
-python implementations on the CPU, but that would probably not be very efficient as they are primarily
-optimized for running on a GPU (or at least this is my guess - I don't have much experience with python).
+The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation of the weights
+and is still 12 GB big. This means that in order to run inference on your computer, you would need to have a video card
+with at least 12 GB of video RAM. Alternatively, you can try to run the python implementations on the CPU, but that
+would probably not be very efficient as they are primarily optimized for running on a GPU (or at least this is my guess -
+I don't have much experience with python).
 
-Looking on the internet, I couldn't find a dedicated CPU implementation that would allow me to run the model
-without a high-end video card. So I decided to write my own inference using a custom build tensor library.
-The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in early development
-stage, but it already allows me to run the GPT-J model.
+I wanted to try and run the model on my MacBook, so I decided to implement the model inference from scratch using my own
+custom build tensor library. The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in
+early development stage, but it already allows me to run the GPT-J model.
 
-On my MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about 2-3 words per second.
+On my 32GB MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about ~6 words per second (1 word
+typically consists of 1 or 2 tokens).
 
 Here is a sample run with prompt `int main(int argc, char ** argv) {`:
 
@@ -68,51 +66,133 @@ main:    total time = 33035.37 ms
 
 real   0m33.171s
 user   3m32.269s
-sys         0m3.686s
+sys      0m3.686s
 
 $
 ```
 
-It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200
-tokens of what looks like to be the beginning of a networking program in C. Pretty cool!
+It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200 tokens of what
+looks like to be the beginning of a networking program in C. Pretty cool!
+
+Here is another run, just for fun:
+
+```
+time ./bin/gpt-j -n 500 -t 8 -p "Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?
+"
+
+gptj_model_load: loading model from 'models/gpt-j-6B/ggml-model.bin' - please wait ...
+gptj_model_load: n_vocab = 50400
+gptj_model_load: n_ctx   = 2048
+gptj_model_load: n_embd  = 4096
+gptj_model_load: n_head  = 16
+gptj_model_load: n_layer = 28
+gptj_model_load: n_rot   = 64
+gptj_model_load: f16     = 1
+gptj_model_load: ggml ctx size = 13334.86 MB
+gptj_model_load: memory_size =  1792.00 MB, n_mem = 57344
+gptj_model_load: ................................... done
+gptj_model_load: model size = 11542.79 MB / num tensors = 285
+main: number of tokens in prompt = 24
+
+Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?
+
+I've inherited a team with some very strange and un-documented practices, one of them is that they use an old custom
+application with a very slow tech stack written in Python that the team doesn't want to touch but also doesn't want to
+throw away as it has some "legacy" code in it.
+
+The problem is, the tech stack is very very slow.
+
+They have a single web server on a VM that is slow.
+The server is a little bit busy (not very busy though) and they have a lot of processes (30+ that are constantly being
+spawned by the application)
+They have an application that is single threaded and was written in Python and the team don't want to touch this, and
+the application is very slow.
+
+My task as a new member of the team is to fix this.
+
+I'm a senior dev on the team (3 years on the project) and have been told that I will take the lead on this task. I know
+next to nothing about Python. So here is what I have so far.
+
+What I have done is I've been trying to debug the processes with the "ps" command. This way I can see what is running
+and where. From what I see, the application spawns 10 processes a minute and some of them are used for nothing.
+
+I have also started to look for the code. The application source is not in GitHub or any other repository, it is only on
+our internal GitLab.
+
+What I've found so far:
+
+The application uses a custom SQLAlchemy implementation to interact with the data. I've looked at the source, it looks
+like an object cache or something like that. But from what I've seen, the cache gets full every 20 minutes and then gets
+cleared with a special command.
+
+Another strange thing is that the application creates a file for every entry in the database (even if the entry already
+exists). I've looked at the file to see if it contains something, but it seems to be a JSON file with lots of records.
+
+The other strange thing is that I can only find the database tables in the GitLab repository and not the code. So I
+can't really understand how the application is supposed to interact with the database.
+
+I also found a "log" directory, but the code is encrypted with AES. From what I've found, it is in
+
+main: mem per token = 16430420 bytes
+main:     load time =  3900.10 ms
+main:   sample time =    32.58 ms
+main:  predict time = 68049.91 ms / 130.11 ms per token
+main:    total time = 73020.05 ms
+
+real   1m13.156s
+user   9m1.328s
+sys.    0m7.103s
+```
 
 ## Implementation details
 
-The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core
-computations are performed by the `ggml` library.
+The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core computations are
+performed by the [ggml](https://github.com/ggerganov/ggml/blob/master/include/ggml/ggml.h) library.
+
+
+#### Matrix multiplication
 
-The most performance critical part of the implementation is of course the matrix multiplication routine.
-99% of the time is spent here, so it is important to optimize this as much as possible.
+The most performance critical part of the implementation is of course the matrix multiplication routine. 99% of the time
+is spent here, so it was important to optimize this as much as possible.
 
 On Arm64, I utilize the 128-bit NEON intrinsics for 16-bit floating point operations:
 
 https://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/src/ggml.c#L187-L243
 
-These instructions allow each core to operate simultaneously on 64 floating point numbers. I'm no expert
-in SIMD, but after quite some trials this was the most efficient code for dot product that I could come up
-with. Combined with the parallel computation on 8 CPU threads, I think I got close to the maximum performance
-that one could possibly get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to
-implement this.
+These instructions allow each core to operate simultaneously on 64 16-bit floats. I'm no expert in SIMD, but after quite
+some trials this was the most efficient code for dot product of a row and column that I could come up with. Combined
+with the parallel computation on 8 CPU threads, I believe I'm close to the maximum performance that one could possibly
+get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to implement this.
+
+
+#### Attempt to use the M1 GPU
 
-One interesting property of the GPT-J transformer architecture is that it allows you to perform part
-of the inference in parallel - i.e. the Feed-forward layer can be computed in parallel to the Self-Attention
-layer:
+One interesting property of the GPT-J transformer architecture is that it allows you to perform part of the inference in
+parallel - i.e. the Feed-forward network can be computed in parallel to the Self-attention layer:
 
 https://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/examples/gpt-j/main.cpp#L507-L531
 
-So I thought why not bring in the M1 GPU to compute half of the neural network in parallel to the CPU.
-Thanks to the shared memory model, it was relatively easy to offload half of the computation to the GPU
-using [Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders).
-However, to my surprise, I did not get any performance improvement at all. My conclusion was that the
-8-thread NEON CPU computation is basically saturating the memory bandwidth of the M1 and since the CPU
-and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the
-GPU. Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same
-performance as the 8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel.
-But of course, I could be totally wrong and somehow my implementation wasn't utilizing the resources 
-correctly.
+So I thought why not try and bring in the M1 GPU to compute half of the neural network in parallel to the CPU and
+potentially gain some extra performance. Thanks to the M1's shared memory model, it was relatively easy to offload part
+of the computation to the GPU using Apple's [Metal Performance
+Shaders](https://developer.apple.com/documentation/metalperformanceshaders). The GPU shares the host memory, so there is
+no need to copy the data back and forth as you would normally do with Cuda or OpenCL. The weight matrices are directly
+available to be used by the GPU.
 
-Another property of my implementation is that it does not perform any memory allocations once the model
-is loaded into memory. All required memory is allocated at the start of the program.
+However, to my surprise, using MPS together with the CPU did not lead to any performance improvement at all. My
+conclusion was that the 8-thread NEON CPU computation is already saturating the memory bandwidth of the M1 and since
+the CPU and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the GPU.
+Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same performance as the
+8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel. But of course, my explanation
+could be totally wrong and somehow the implementation wasn't utilizing the resources correctly.
+
+In the end, I decided to not use MPS or the GPU all together.
+
+### Zero memory allocations
+
+Another property of my implementation is that it does not perform any memory allocations once the model is loaded into
+memory. All required memory is allocated at the start of the program with a single `malloc` (technically 2 calls, but
+that is not important).
 
 ## Usage
 
@@ -134,22 +214,26 @@ make -j4 gpt-j
 ```
 
 To run the `gpt-j` tool, you need the 12GB `ggml-model.bin` file which contains the GPT-J model in
-[ggml](https://github.com/ggerganov/ggml) format. In the instructions above, I download the binary file
+[ggml](https://github.com/ggerganov/ggml) compatible format. In the instructions above, I download the binary file
 directly from one of my servers, using the [download-ggml-model.sh](download-ggml-model.sh) script.
 
 ---
 
-Alternatively, you can perform the conversion yourself.
+Alternatively, if you don't want to download the 12GB ggml model file, you can perform the conversion yourself using
+python.
 
 First, you need to download the full GPT-J model from here: https://huggingface.co/EleutherAI/gpt-j-6B
 
-Note that the full model is quite big - about 72 GB. After you download it, you need to make the
-conversion using the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the
-`ggml-model.bin` file, which you can then use with the `gpt-j` program.
+Note that the full model is quite big - about 72 GB. After you download it, you need to convert it to ggml format using
+the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the `ggml-model.bin` file, which you can
+then use with the `gpt-j` program.
+
 
 ## GPT-2
 
-I have also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse
-quality compared to GPT-J, but are much faster to execute.
+I also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse quality compared to GPT-J,
+but are much faster to execute.
+
+For example, the Small GPT-2 model is only 240 MB big and the inference speed on my MacBook is about 200 tokens/sec.
 
-Checkout the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
+For more details, checkout the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
index aee9abfffa8c6081e0c8dd3b25f0760e27872047..92ce0f25815a0a84257a1ff57ffdd19ed3f93d4b 100644 (file)
@@ -14,7 +14,7 @@
 
 struct gpt_params {
     int32_t seed      = -1; // RNG seed
-    int32_t n_threads = std::min(8, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict = 200; // new tokens to predict
 
     // sampling parameters
index c8a6396d0a20703e41f7b442f2e4f4b7eec14a46..af40156fe50e917ec03ab8022ee8484282405650 100644 (file)
@@ -48,6 +48,10 @@ set(TARGET ggml)
 #    endif()
 #endif()
 
+if (GGML_PERF)
+    set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_PERF)
+endif()
+
 add_library(${TARGET}
     ggml.c
     )
index bef3dc50581d3a8029623959d39f3a6b85c1ae83..726352b0b8dc47ef4f025d88bc00a7ce19a6b7bf 100644 (file)
 #include <pthread.h>
 
 #define GGML_DEBUG 0
+#define GGML_MEM_ALIGN 16
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 
-#define GGML_MEM_ALIGN 16
-
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
@@ -117,7 +116,6 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
 // timing
 //
 
-// TODO: need to be able to disable these in performance critical code since they make slow system calls
 int64_t ggml_time_ms(void) {
     struct timespec ts;
     clock_gettime(CLOCK_MONOTONIC, &ts);
@@ -138,6 +136,18 @@ int64_t ggml_cycles_per_ms(void) {
     return CLOCKS_PER_SEC/1000;
 }
 
+#ifdef GGML_PERF
+#define ggml_perf_time_ms()       ggml_time_ms()
+#define ggml_perf_time_us()       ggml_time_us()
+#define ggml_perf_cycles()        ggml_cycles()
+#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
+#else
+#define ggml_perf_time_ms()       0
+#define ggml_perf_time_us()       0
+#define ggml_perf_cycles()        0
+#define ggml_perf_cycles_per_ms() 0
+#endif
+
 //
 // cache line
 //
@@ -3053,7 +3063,7 @@ void ggml_compute_forward_mul_mat_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    int64_t t0 = ggml_time_us();
+    int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
     const int ne00 = src0->ne[0];
@@ -3232,7 +3242,7 @@ void ggml_compute_forward_mul_mat_f32(
         }
     }
 
-    //int64_t t1 = ggml_time_us();
+    //int64_t t1 = ggml_perf_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
     //if (t1 - t0 > 10) {
@@ -3251,7 +3261,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    int64_t t0 = ggml_time_us();
+    int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
     const int ne00 = src0->ne[0];
@@ -4619,8 +4629,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         }
     }
 
-    const int64_t perf_start_cycles  = ggml_cycles();
-    const int64_t perf_start_time_us = ggml_time_us();
+    const int64_t perf_start_cycles  = ggml_perf_cycles();
+    const int64_t perf_start_time_us = ggml_perf_time_us();
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
@@ -4632,8 +4642,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         //    continue;
         //}
 
-        const int64_t perf_node_start_cycles  = ggml_cycles();
-        const int64_t perf_node_start_time_us = ggml_time_us();
+        const int64_t perf_node_start_cycles  = ggml_perf_cycles();
+        const int64_t perf_node_start_time_us = ggml_perf_time_us();
 
         // INIT
         struct ggml_compute_params params = {
@@ -4706,8 +4716,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         // performance stats (node)
         {
-            int64_t perf_cycles_cur  = ggml_cycles()  - perf_node_start_cycles;
-            int64_t perf_time_us_cur = ggml_time_us() - perf_node_start_time_us;
+            int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_node_start_cycles;
+            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
 
             node->perf_runs++;
             node->perf_cycles  += perf_cycles_cur;
@@ -4731,8 +4741,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
     // performance stats (graph)
     {
-        int64_t perf_cycles_cur  = ggml_cycles()  - perf_start_cycles;
-        int64_t perf_time_us_cur = ggml_time_us() - perf_start_time_us;
+        int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_start_cycles;
+        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
 
         cgraph->perf_runs++;
         cgraph->perf_cycles  += perf_cycles_cur;