From: Georgi Gerganov
Date: Mon, 19 Sep 2022 21:09:34 +0000 (+0300)
Subject: Update README.md + minor stuff
X-Git-Tag: upstream/0.0.1642~1621
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=f21b84cd217a4cb39cc21c6dcfce30e0bff132f0;p=pkg%2Fggml%2Fsources%2Fggml

Update README.md + minor stuff

- Changed default threads to 4
- Added GGML_PERF for enabling runtime performance timings
---

diff --git a/README.md b/README.md
index 550425b3..9678e394 100644
--- a/README.md
+++ b/README.md
@@ -1,24 +1,23 @@
 # ggml
 
-Tensor library in C for machine learning
+Tensor library for machine learning
 
 ## Features
 
-- Automatic differentiation (WIP)
+- Written in C
 - 16-bit float support
+- Automatic differentiation (WIP)
 - ADAM and L-BFGS optimizers
-- Optimized for Arm64 architectures (i.e. MacBook M1) via NEON intrinsics
+- Optimized for Arm64 architectures (M1) via NEON intrinsics
 - On x86 architectures utilzes AVX intrinsics
 - No third-party dependencies
 - Zero memory allocations during runtime
 
-## Local GPT inference
+## Example - GPT inference
 
-Using ggml you can run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference locally on your computer without any additional software or hardware. You don't even need to install python or any other third-party library.
+With ggml you can efficiently run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference on the CPU.
 
-The example programs are implemented in C++. They run entirely on the CPU.
-
-Here is how to use them:
+Here is how to run the example programs:
 
 ```bash
 # Build ggml + examples
@@ -37,7 +36,7 @@ make -j4 gpt-2 gpt-j
 ./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p "This is an example"
 ```
 
-This is the inference speed for the different models on my MacBook M1 Pro:
+The inference speeds that I get for the different models on my 32GB MacBook M1 Pro are as follows:
 
 | Model | Size | Time / Token |
 | --- | --- | --- |
diff --git a/examples/gpt-2/README.md b/examples/gpt-2/README.md
index 3543bb29..60fea55d 100644
--- a/examples/gpt-2/README.md
+++ b/examples/gpt-2/README.md
@@ -1,7 +1,6 @@
 # gpt-2
 
 This is a C++ example running GPT-2 inference using the [ggml](https://github.com/ggerganov/ggml) library.
-The enitre code of the example is in [main.cpp](main.cpp).
 
 The program runs on the CPU - no video card is required.
 
@@ -73,11 +72,11 @@ main: total time = 629.84 ms
 
 ## Downloading and converting the original models
 
-You can download the original model files using the [download-model.sh](download-model.sh) Bash script.
-The model is in Tensorflow format, so before using it with ggml, we need to convert it to appropriate format.
-This is done via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
+You can download the original model files using the [download-model.sh](download-model.sh) Bash script. The models are
+in Tensorflow format, so in order to use them with ggml, you need to convert them to an appropriate format. This is done
+via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
 
-Here is the entire process for the GPT-2 117M model:
+Here is the entire process for the GPT-2 117M model (download from official site + conversion):
 
 ```
 cd ggml/build
@@ -99,14 +98,13 @@ Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.
 
 ```
 
-This conversion requires that you have python and Tensorflow installed on your computer.
-Still, if you want to avoid this, you can download the already converted ggml models as
-described below.
+This conversion requires that you have python and Tensorflow installed on your computer. Still, if you want to avoid
+this, you can download the already converted ggml models as described below.
 
 ## Downloading the ggml model directly
 
-For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples.
-This way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
+For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples. This
+way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
 
 Here is how to get the 117M ggml model:
 
@@ -123,4 +121,4 @@ You can now use it like this:
 
 ```
 
-At some point, I might stop hosting these models. So in that case, simply revert to the manual process above.
+At some point, I might decide to stop hosting these models. So in that case, simply revert to the manual process above.
diff --git a/examples/gpt-j/README.md b/examples/gpt-j/README.md
index c5e0007c..68c41361 100644
--- a/examples/gpt-j/README.md
+++ b/examples/gpt-j/README.md
@@ -4,25 +4,23 @@ Local GPT-J inference on your computer using C/C++
 
 No video card required. You just need to have 16 GB of RAM.
 
-For example, you can run this on a 16 GB MacBook M1.
-
 ## Motivation
 
-The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that
-allows you to generate coherent, human-like text given a certain context (prompt).
+The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that allows you to
+generate coherent, human-like text given a certain context (prompt).
 
-The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation
-of the weights and is still 12 GB big. This means that in order to run inference on your computer, you
-would need to have a video card with at least 12 GB of video RAM. Alternatively, you can try to run the
-python implementations on the CPU, but that would probably not be very efficient as they are primarily
-optimized for running on a GPU (or at least this is my guess - I don't have much experience with python).
+The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation of the weights
+and is still 12 GB big. This means that in order to run inference on your computer, you would need to have a video card
+with at least 12 GB of video RAM. Alternatively, you can try to run the python implementations on the CPU, but that
+would probably not be very efficient as they are primarily optimized for running on a GPU (or at least this is my guess -
+I don't have much experience with python).
 
-Looking on the internet, I couldn't find a dedicated CPU implementation that would allow me to run the model
-without a high-end video card. So I decided to write my own inference using a custom build tensor library.
-The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in early development
-stage, but it already allows me to run the GPT-J model.
+I wanted to try and run the model on my MacBook, so I decided to implement the model inference from scratch using my own
+custom-built tensor library. The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in
+early development stage, but it already allows me to run the GPT-J model.
 
-On my MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about 2-3 words per second.
+On my 32GB MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about 6 words per second (1 word
+typically consists of 1 or 2 tokens).
 
 Here is a sample run with prompt `int main(int argc, char ** argv) {`:
 
@@ -68,51 +66,133 @@ main: total time = 33035.37 ms
 
 real 0m33.171s
 user 3m32.269s
-sys 0m3.686s
+sys 0m3.686s $
 ```
 
-It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200
-tokens of what looks like to be the beginning of a networking program in C. Pretty cool!
+It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200 tokens of what
+looks to be the beginning of a networking program in C. Pretty cool!
+
+Here is another run, just for fun:
+
+```
+time ./bin/gpt-j -n 500 -t 8 -p "Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?
+"
+
+gptj_model_load: loading model from 'models/gpt-j-6B/ggml-model.bin' - please wait ...
+gptj_model_load: n_vocab = 50400
+gptj_model_load: n_ctx = 2048
+gptj_model_load: n_embd = 4096
+gptj_model_load: n_head = 16
+gptj_model_load: n_layer = 28
+gptj_model_load: n_rot = 64
+gptj_model_load: f16 = 1
+gptj_model_load: ggml ctx size = 13334.86 MB
+gptj_model_load: memory_size = 1792.00 MB, n_mem = 57344
+gptj_model_load: ................................... done
+gptj_model_load: model size = 11542.79 MB / num tensors = 285
+main: number of tokens in prompt = 24
+
+Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?
+
+I've inherited a team with some very strange and un-documented practices, one of them is that they use an old custom
+application with a very slow tech stack written in Python that the team doesn't want to touch but also doesn't want to
+throw away as it has some "legacy" code in it.
+
+The problem is, the tech stack is very very slow.
+
+They have a single web server on a VM that is slow.
+The server is a little bit busy (not very busy though) and they have a lot of processes (30+ that are constantly being
+spawned by the application)
+They have an application that is single threaded and was written in Python and the team don't want to touch this, and
+the application is very slow.
+
+My task as a new member of the team is to fix this.
+
+I'm a senior dev on the team (3 years on the project) and have been told that I will take the lead on this task. I know
+next to nothing about Python. So here is what I have so far.
+
+What I have done is I've been trying to debug the processes with the "ps" command. This way I can see what is running
+and where. From what I see, the application spawns 10 processes a minute and some of them are used for nothing.
+
+I have also started to look for the code. The application source is not in GitHub or any other repository, it is only on
+our internal GitLab.
+
+What I've found so far:
+
+The application uses a custom SQLAlchemy implementation to interact with the data. I've looked at the source, it looks
+like an object cache or something like that. But from what I've seen, the cache gets full every 20 minutes and then gets
+cleared with a special command.
+
+Another strange thing is that the application creates a file for every entry in the database (even if the entry already
+exists). I've looked at the file to see if it contains something, but it seems to be a JSON file with lots of records.
+
+The other strange thing is that I can only find the database tables in the GitLab repository and not the code. So I
+can't really understand how the application is supposed to interact with the database.
+
+I also found a "log" directory, but the code is encrypted with AES. From what I've found, it is in
+
+main: mem per token = 16430420 bytes
+main: load time = 3900.10 ms
+main: sample time = 32.58 ms
+main: predict time = 68049.91 ms / 130.11 ms per token
+main: total time = 73020.05 ms
+
+real 1m13.156s
+user 9m1.328s
+sys. 0m7.103s
+```
 
 ## Implementation details
 
-The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core
-computations are performed by the `ggml` library.
+The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core computations are
+performed by the [ggml](https://github.com/ggerganov/ggml/blob/master/include/ggml/ggml.h) library.
+
+
+#### Matrix multiplication
 
-The most performance critical part of the implementation is of course the matrix multiplication routine.
-99% of the time is spent here, so it is important to optimize this as much as possible.
+The most performance critical part of the implementation is of course the matrix multiplication routine. 99% of the time
+is spent here, so it was important to optimize this as much as possible.
 
 On Arm64, I utilize the 128-bit NEON intrinsics for 16-bit floating point operations:
 
 https://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/src/ggml.c#L187-L243
 
-These instructions allow each core to operate simultaneously on 64 floating point numbers. I'm no expert
-in SIMD, but after quite some trials this was the most efficient code for dot product that I could come up
-with. Combined with the parallel computation on 8 CPU threads, I think I got close to the maximum performance
-that one could possibly get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to
-implement this.
+These instructions allow each core to operate simultaneously on 64 16-bit floats. I'm no expert in SIMD, but after quite
+some trials this was the most efficient code for dot product of a row and column that I could come up with. Combined
+with the parallel computation on 8 CPU threads, I believe I'm close to the maximum performance that one could possibly
+get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to implement this.
+
+
+#### Attempt to use the M1 GPU
 
-One interesting property of the GPT-J transformer architecture is that it allows you to perform part
-of the inference in parallel - i.e. the Feed-forward layer can be computed in parallel to the Self-Attention
-layer:
+One interesting property of the GPT-J transformer architecture is that it allows you to perform part of the inference in
+parallel - i.e. the Feed-forward network can be computed in parallel to the Self-attention layer:
 
 https://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/examples/gpt-j/main.cpp#L507-L531
 
-So I thought why not bring in the M1 GPU to compute half of the neural network in parallel to the CPU.
-Thanks to the shared memory model, it was relatively easy to offload half of the computation to the GPU
-using [Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders).
-However, to my surprise, I did not get any performance improvement at all. My conclusion was that the
-8-thread NEON CPU computation is basically saturating the memory bandwidth of the M1 and since the CPU
-and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the
-GPU. Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same
-performance as the 8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel.
-But of course, I could be totally wrong and somehow my implementation wasn't utilizing the resources
-correctly.
+So I thought why not try and bring in the M1 GPU to compute half of the neural network in parallel to the CPU and
+potentially gain some extra performance. Thanks to the M1's shared memory model, it was relatively easy to offload part
+of the computation to the GPU using Apple's [Metal Performance
+Shaders](https://developer.apple.com/documentation/metalperformanceshaders). The GPU shares the host memory, so there is
+no need to copy the data back and forth as you would normally do with Cuda or OpenCL. The weight matrices are directly
+available to be used by the GPU.
 
-Another property of my implementation is that it does not perform any memory allocations once the model
-is loaded into memory. All required memory is allocated at the start of the program.
+However, to my surprise, using MPS together with the CPU did not lead to any performance improvement at all. My
+conclusion was that the 8-thread NEON CPU computation is already saturating the memory bandwidth of the M1 and since
+the CPU and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the GPU.
+Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same performance as the
+8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel. But of course, my explanation
+could be totally wrong and somehow the implementation wasn't utilizing the resources correctly.
+
+In the end, I decided to not use MPS or the GPU at all.
+
+### Zero memory allocations
+
+Another property of my implementation is that it does not perform any memory allocations once the model is loaded into
+memory. All required memory is allocated at the start of the program with a single `malloc` (technically 2 calls, but
+that is not important).
 
 ## Usage
 
@@ -134,22 +214,26 @@ make -j4 gpt-j
 ```
 
 To run the `gpt-j` tool, you need the 12GB `ggml-model.bin` file which contains the GPT-J model in
-[ggml](https://github.com/ggerganov/ggml) format. In the instructions above, I download the binary file
+[ggml](https://github.com/ggerganov/ggml) compatible format. In the instructions above, I download the binary file
 directly from one of my servers, using the [download-ggml-model.sh](download-ggml-model.sh) script.
 
 ---
 
-Alternatively, you can perform the conversion yourself.
+Alternatively, if you don't want to download the 12GB ggml model file, you can perform the conversion yourself using
+python.
 
 First, you need to download the full GPT-J model from here: https://huggingface.co/EleutherAI/gpt-j-6B
 
-Note that the full model is quite big - about 72 GB. After you download it, you need to make the
-conversion using the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the
-`ggml-model.bin` file, which you can then use with the `gpt-j` program.
+Note that the full model is quite big - about 72 GB. After you download it, you need to convert it to ggml format using
+the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the `ggml-model.bin` file, which you can
+then use with the `gpt-j` program.
+
 
 ## GPT-2
 
-I have also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse
-quality compared to GPT-J, but are much faster to execute.
+I also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse quality compared to GPT-J,
+but are much faster to execute.
+
+For example, the Small GPT-2 model is only 240 MB big and the inference speed on my MacBook is about 200 tokens/sec.
 
-Checkout the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
+For more details, check out the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
diff --git a/examples/utils.h b/examples/utils.h
index aee9abff..92ce0f25 100644
--- a/examples/utils.h
+++ b/examples/utils.h
@@ -14,7 +14,7 @@ struct gpt_params {
     int32_t seed = -1; // RNG seed
-    int32_t n_threads = std::min(8, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict = 200; // new tokens to predict
 
     // sampling parameters
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c8a6396d..af40156f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -48,6 +48,10 @@ set(TARGET ggml)
 #    endif()
 #endif()
 
+if (GGML_PERF)
+    set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_PERF)
+endif()
+
 add_library(${TARGET}
     ggml.c
     )
diff --git a/src/ggml.c b/src/ggml.c
index bef3dc50..726352b0 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -12,12 +12,11 @@
 #include
 
 #define GGML_DEBUG 0
+#define GGML_MEM_ALIGN 16
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 
-#define GGML_MEM_ALIGN 16
-
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
@@ -117,7 +116,6 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
 // timing
 //
 
-// TODO: need to be able to disable these in performance critical code since they make slow system calls
 int64_t ggml_time_ms(void) {
     struct timespec ts;
     clock_gettime(CLOCK_MONOTONIC, &ts);
@@ -138,6 +136,18 @@ int64_t ggml_cycles_per_ms(void) {
     return CLOCKS_PER_SEC/1000;
 }
 
+#ifdef GGML_PERF
+#define ggml_perf_time_ms() ggml_time_ms()
+#define ggml_perf_time_us() ggml_time_us()
+#define ggml_perf_cycles() ggml_cycles()
+#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
+#else
+#define ggml_perf_time_ms() 0
+#define ggml_perf_time_us() 0
+#define ggml_perf_cycles() 0
+#define ggml_perf_cycles_per_ms() 0
+#endif
+
 //
 // cache line
 //
@@ -3053,7 +3063,7 @@ void ggml_compute_forward_mul_mat_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    int64_t t0 = ggml_time_us();
+    int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
     const int ne00 = src0->ne[0];
@@ -3232,7 +3242,7 @@ void ggml_compute_forward_mul_mat_f32(
         }
     }
 
-    //int64_t t1 = ggml_time_us();
+    //int64_t t1 = ggml_perf_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
     //if (t1 - t0 > 10) {
@@ -3251,7 +3261,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    int64_t t0 = ggml_time_us();
+    int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
     const int ne00 = src0->ne[0];
@@ -4619,8 +4629,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         }
     }
 
-    const int64_t perf_start_cycles = ggml_cycles();
-    const int64_t perf_start_time_us = ggml_time_us();
+    const int64_t perf_start_cycles = ggml_perf_cycles();
+    const int64_t perf_start_time_us = ggml_perf_time_us();
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
@@ -4632,8 +4642,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         //    continue;
         //}
 
-        const int64_t perf_node_start_cycles = ggml_cycles();
-        const int64_t perf_node_start_time_us = ggml_time_us();
+        const int64_t perf_node_start_cycles = ggml_perf_cycles();
+        const int64_t perf_node_start_time_us = ggml_perf_time_us();
 
         // INIT
         struct ggml_compute_params params = {
@@ -4706,8 +4716,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         // performance stats (node)
         {
-            int64_t perf_cycles_cur = ggml_cycles() - perf_node_start_cycles;
-            int64_t perf_time_us_cur = ggml_time_us() - perf_node_start_time_us;
+            int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles;
+            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
 
             node->perf_runs++;
             node->perf_cycles += perf_cycles_cur;
@@ -4731,8 +4741,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
     // performance stats (graph)
    {
-        int64_t perf_cycles_cur = ggml_cycles() - perf_start_cycles;
-        int64_t perf_time_us_cur = ggml_time_us() - perf_start_time_us;
+        int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
+        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
 
         cgraph->perf_runs++;
         cgraph->perf_cycles += perf_cycles_cur;
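
Usage note on the two functional changes in this commit: the new `ggml_perf_*` macros only resolve to the real timing functions when the code is compiled with `-DGGML_PERF` (which `src/CMakeLists.txt` now adds to `GGML_EXTRA_FLAGS` when the `GGML_PERF` CMake variable is set), and the examples now default to at most 4 threads via `gpt_params::n_threads`, which can still be raised with the `-t` flag used in the sample runs above. A minimal sketch of how one might exercise both changes - the exact cmake invocation is an assumption, since the configure step is not part of this diff:

```bash
# Assumed out-of-tree CMake build, as in the README instructions.
# Setting the GGML_PERF variable makes src/CMakeLists.txt append -DGGML_PERF
# to the extra compile flags, so the ggml_perf_* macros call the real timers
# instead of expanding to 0.
mkdir -p build && cd build
cmake -DGGML_PERF=ON ..
make -j4 gpt-2 gpt-j

# n_threads now defaults to min(4, hardware threads); raise it explicitly
# with -t, as in the sample runs above.
./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -t 8 -p "This is an example"
```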