]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : 4-bit Integer quantisation + many llama.cpp improvements (#27)
authorGeorgi Gerganov <redacted>
Wed, 29 Mar 2023 19:21:36 +0000 (22:21 +0300)
committerGitHub <redacted>
Wed, 29 Mar 2023 19:21:36 +0000 (22:21 +0300)
* gq : attempt at n-bit quantization

* gq : add amax based method 3

* gq : progress on method 2

* gq : method 4 (AVX2)

* gq : method 4 (ARM)

* gq : method 4 (AVX2 attempt) + method 5 (no min)

* gq : method 5 (ARM)

* gpt-2 : model conversion for Q4_0 quantization

* ggml : Q4_0 quantization support (ggml_get_rows())

* gpt-2 : loading Q4_0 quantized model

* ggml : q4_0 quantization support

* ggml : q4_1 quantization support (seems to work for bigger models)

* gpt-2 : add gpt-2-quantize tool for quantizing f32 GPT-2 models

* ggml : 4-bit quantization works (only scalar for now)

* gq : add method 6 (ARM)

* ggml : vectorized mad q4_0 (ARM)

* ggml : vectorized quantize_row_q4_0 (ARM)

* ggml : simplify mad q4_0 (ARM)

* ggml : minor indentations

* gpt-j : support for 4-bit quantized model inference

* ggml : GGML_ASSERT() instead of assert() where appropriate

* gpt : avoid ggml_transpose on model tensors (new models!)

* gpt-2 : minor

* gpt-j : fix conversion for FP16 models (such as GPT-JT-6B)

* ggml : add ggml_compute_forward_rope_f16()

* gpt : fix memory usage computation

* ggml : fix ggml_is_contiguous() to take into account blck size

* whisper : add whisper-qunatize tool

* whisper : add support for quantized models

* whisper : mem usage based on model format type

* gpt : seems not worth to use FP16 for KV cache

* gpt : support quantisation of f16 models files

* ggml : fixes for rpi4

* whisper : add Q4_1 model sizes

* ggml : add WASM SIMD for Q4_0

* utils : print quantization histograms

* ggml : sync all changes from llama.cpp and whisper.cpp

* ggml : finalize the Q4_1 quantization for ARM_NEON

20 files changed:
CMakeLists.txt
examples/gpt-2/CMakeLists.txt
examples/gpt-2/README.md
examples/gpt-2/convert-ckpt-to-ggml.py
examples/gpt-2/main.cpp
examples/gpt-2/quantize.cpp [new file with mode: 0644]
examples/gpt-j/CMakeLists.txt
examples/gpt-j/convert-h5-to-ggml.py
examples/gpt-j/main.cpp
examples/gpt-j/quantize.cpp [new file with mode: 0644]
examples/utils.h
examples/whisper/CMakeLists.txt
examples/whisper/convert-pt-to-ggml.py
examples/whisper/main.cpp
examples/whisper/quantize.cpp [new file with mode: 0644]
examples/whisper/whisper.cpp
examples/whisper/whisper.h
include/ggml/ggml.h
src/ggml.c
tests/test-mul-mat2.c

index d88c5b101348b8aaef6cd87229474934488268dc..54d18b0031b2a8993c8706cce111bf2848a44b82 100644 (file)
@@ -31,7 +31,7 @@ option(GGML_NO_ACCELERATE           "ggml: disable Accelerate framework" OFF)
 # sanitizers
 
 if (GGML_SANITIZE_THREAD)
-    set(CMAKE_C_FLAGS   "${CMAKE_C_FLAGS} -fsanitize=thread")
+    set(CMAKE_C_FLAGS   "${CMAKE_C_FLAGS}   -fsanitize=thread")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
 endif()
 
index 9960cfe81995b216737851aed3b87eedef43a92b..3b7ab5efe771f42a5f9b0c3fa9b2df42de48e35a 100644 (file)
@@ -4,3 +4,10 @@
 set(TEST_TARGET gpt-2)
 add_executable(${TEST_TARGET} main.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
+
+#
+# gpt-2-quantize
+
+set(TEST_TARGET gpt-2-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
index 60fea55dc759248172ff2e5d2b97cb239499edfe..ed61f861f678750e4ab947a4ddb979ba44e467e9 100644 (file)
@@ -94,7 +94,7 @@ Done! Model '117M' saved in 'models/gpt-2-117M/'
 
 Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.
 
-  python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/
+  python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/ 1
 
 ```
 
index 7ae438013a08341431a67af78c76526f9408f952..60cd963d21e27000f971f672f0ead3c96293953a 100644 (file)
@@ -45,8 +45,18 @@ def bytes_to_unicode():
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
 
-if len(sys.argv) < 2:
-    print("Usage: convert-ckpt-to-ggml.py dir-model [use-f32]\n")
+# helper method to convert a numpy array to different float types
+def convert_to_ftype(data, ftype):
+    # fp16
+    if ftype == 1:
+        return data.astype(np.float16)
+
+    assert False, "Invalid ftype: " + str(ftype)
+
+if len(sys.argv) < 3:
+    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
+    print("  ftype == 0 -> float32")
+    print("  ftype == 1 -> float16")
     sys.exit(1)
 
 # output in the same directory as the model
@@ -59,11 +69,20 @@ with open(dir_model + "/encoder.json", "r") as f:
 with open(dir_model + "/hparams.json", "r") as f:
     hparams = json.load(f)
 
-# use 16-bit or 32-bit floats
-use_f16 = True
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
 if len(sys.argv) > 2:
-    use_f16 = False
-    fname_out = sys.argv[1] + "/ggml-model-f32.bin"
+    ftype = int(sys.argv[2])
+    if ftype < 0 or ftype > 1:
+        print("Invalid ftype: " + str(ftype))
+        sys.exit(1)
+    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
 
 list_vars = tf.train.list_variables(dir_model)
 
@@ -75,7 +94,7 @@ fout.write(struct.pack("i", hparams["n_ctx"]))
 fout.write(struct.pack("i", hparams["n_embd"]))
 fout.write(struct.pack("i", hparams["n_head"]))
 fout.write(struct.pack("i", hparams["n_layer"]))
-fout.write(struct.pack("i", use_f16))
+fout.write(struct.pack("i", ftype))
 
 byte_encoder = bytes_to_unicode()
 byte_decoder = {v:k for k, v in byte_encoder.items()}
@@ -93,9 +112,22 @@ for name, shape in list_vars:
     data = tf.train.load_variable(dir_model, name).squeeze()
     n_dims = len(data.shape);
 
-    # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype = 0;
-    if use_f16:
+    # for efficiency - transpose the projection matrices
+    # "model/h.*/attn/c_attn/w"
+    # "model/h.*/attn/c_proj/w"
+    # "model/h.*/mlp/c_fc/w"
+    # "model/h.*/mlp/c_proj/w"
+    if name[-14:] == "/attn/c_attn/w" or \
+       name[-14:] == "/attn/c_proj/w" or \
+       name[-11:] == "/mlp/c_fc/w" or \
+       name[-13:] == "/mlp/c_proj/w":
+        print("  Transposing")
+        data = data.transpose()
+
+    dshape = data.shape
+
+    ftype_cur = 0
+    if ftype != 0:
         # match name:
         #  "model/wte"
         #  "model/h.*/attn/c_attn/w"
@@ -103,24 +135,19 @@ for name, shape in list_vars:
         #  "model/h.*/mlp/c_fc/w"
         #  "model/h.*/mlp/c_proj/w"
         if name == "model/wte" or name[-2:] == "/w":
-            print("  Converting to float16")
-            data = data.astype(np.float16)
-            ftype = 1
+            print("  Converting to " + ftype_str[ftype])
+            data = convert_to_ftype(data, ftype)
+            ftype_cur = ftype
         else:
             print("  Converting to float32")
             data = data.astype(np.float32)
-            ftype = 0
-
-    # for efficiency - transpose the projection matrices
-    if name[-13:] == "/mlp/c_proj/w":
-        print("  Transposing")
-        data = data.transpose()
+            ftype_cur = 0
 
     # header
     str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
     for i in range(n_dims):
-        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
     fout.write(str);
 
     # data
index 5a0ab01fe5b94cbab0126b49e1a5d5c5c99abf02..f371bc66735d821e2b2e139c44f62d0e9275c37d 100644 (file)
@@ -42,7 +42,7 @@ struct gpt2_layer {
     struct ggml_tensor * c_mlp_fc_w;
     struct ggml_tensor * c_mlp_fc_b;
 
-    struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
+    struct ggml_tensor * c_mlp_proj_w;
     struct ggml_tensor * c_mlp_proj_b;
 };
 
@@ -130,9 +130,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
-    // for the big tensors, we have the option to store the data in 16-bit floats
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
     // in order to save memory and also to speed up the computation
-    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    ggml_type wtype = GGML_TYPE_COUNT;
+    switch (model.hparams.f16) {
+        case 0: wtype = GGML_TYPE_F32;  break;
+        case 1: wtype = GGML_TYPE_F16;  break;
+        case 2: wtype = GGML_TYPE_Q4_0; break;
+        case 3: wtype = GGML_TYPE_Q4_1; break;
+        default:
+                {
+                    fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
+                            __func__, fname.c_str(), model.hparams.f16);
+                    return false;
+                }
+    }
+
+    const ggml_type wtype2 = GGML_TYPE_F32;
 
     auto & ctx = model.ctx;
 
@@ -146,32 +160,32 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         const int n_ctx   = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;
 
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
 
-        ctx_size += n_vocab*n_embd*ggml_type_size(wtype);         // wte
-        ctx_size +=   n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
+        ctx_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
 
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
 
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
 
-        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype));         // c_attn_attn_w
-        ctx_size += n_layer*(       3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
+        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
+        ctx_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype));           // c_attn_proj_w
-        ctx_size += n_layer*(       n_embd*ggml_type_size(GGML_TYPE_F32));   // c_attn_proj_b
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
+        ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_fc_w
-        ctx_size += n_layer*(       4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_proj_w
-        ctx_size += n_layer*(         n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
+        ctx_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
 
         ctx_size += (6 + 12*n_layer)*256; // object overhead
 
@@ -219,23 +233,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         for (int i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
 
-            layer.ln_1_g             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
-            layer.ln_1_b             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
 
-            layer.ln_2_g             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
-            layer.ln_2_b             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
 
-            layer.c_attn_attn_w      = ggml_new_tensor_2d(ctx, wtype,         3*n_embd, n_embd);
-            layer.c_attn_attn_b      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
+            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);
+            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
 
-            layer.c_attn_proj_w      = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
-            layer.c_attn_proj_b      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
+            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
 
-            layer.c_mlp_fc_w         = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
-            layer.c_mlp_fc_b         = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);
+            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
 
-            layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
-            layer.c_mlp_proj_b       = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
 
             // map by name
             model.tensors["model/h" + std::to_string(i) + "/ln_1/g"]        = layer.ln_1_g;
@@ -253,7 +267,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"]    = layer.c_mlp_fc_w;
             model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"]    = layer.c_mlp_fc_b;
 
-            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w_trans;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w;
             model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"]  = layer.c_mlp_proj_b;
         }
     }
@@ -321,9 +335,26 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
                 return false;
             }
 
-            const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+            if (0) {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
 
-            if (nelements*bpe != ggml_nbytes(tensor)) {
+            size_t bpe = 0;
+
+            switch (ftype) {
+                case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
+                case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
+                case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
+                case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
+                default:
+                        {
+                            fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                            return false;
+                        }
+            };
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                 fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                 return false;
@@ -331,7 +362,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
             fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
 
-            //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
         }
 
@@ -433,7 +463,7 @@ bool gpt2_eval(
         // [2304, N]
         {
             cur = ggml_mul_mat(ctx0,
-                    ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
+                    model.layers[il].c_attn_attn_w,
                     cur);
 
             cur = ggml_add(ctx0,
@@ -509,11 +539,13 @@ bool gpt2_eval(
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             // [n_past + N, 64, 12]
             struct ggml_tensor * V_trans =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
-                            n_embd/n_head, n_head, n_past + N),
-                        1, 2, 0, 3);
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
 
             // KQV = transpose(V) * KQ_soft_max
             // [64, N, 12]
@@ -540,7 +572,7 @@ bool gpt2_eval(
         // [768, N]
         {
             cur = ggml_mul_mat(ctx0,
-                    ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
+                    model.layers[il].c_attn_proj_w,
                     cur);
 
             cur = ggml_add(ctx0,
@@ -577,7 +609,7 @@ bool gpt2_eval(
             // cur = fc_w*cur + fc_b
             // [3072, N]
             cur = ggml_mul_mat(ctx0,
-                    ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
+                    model.layers[il].c_mlp_fc_w,
                     cur);
 
             cur = ggml_add(ctx0,
@@ -597,7 +629,7 @@ bool gpt2_eval(
             // cur = proj_w*cur + proj_b
             // [768, N]
             cur = ggml_mul_mat(ctx0,
-                    model.layers[il].c_mlp_proj_w_trans,
+                    model.layers[il].c_mlp_proj_w,
                     cur);
 
             cur = ggml_add(ctx0,
@@ -714,8 +746,12 @@ int main(int argc, char ** argv) {
 
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
-    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-    printf("\n");
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
+    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
+        printf("%d ", embd_inp[i]);
+    }
+    printf("\n\n");
 
     // submit the input prompt token-by-token
     // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
diff --git a/examples/gpt-2/quantize.cpp b/examples/gpt-2/quantize.cpp
new file mode 100644 (file)
index 0000000..3cc48ea
--- /dev/null
@@ -0,0 +1,322 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <regex>
+
+// default hparams (GPT-2 117M)
+struct gpt2_hparams {
+    int32_t n_vocab = 50257;
+    int32_t n_ctx   = 1024;
+    int32_t n_embd  = 768;
+    int32_t n_head  = 12;
+    int32_t n_layer = 12;
+    int32_t f16     = 1;
+};
+
+// quantize a model
+bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    switch (itype) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
+    };
+
+    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
+        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
+        return false;
+    }
+
+    gpt_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+    }
+
+    gpt2_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fout.write((char *) &itype,           sizeof(hparams.f16));
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        finp.read ((char *) &n_vocab, sizeof(n_vocab));
+        fout.write((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size_org = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t>     data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float>       data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
+            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            finp.read (&name[0], length);
+
+            {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%24s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
+            }
+
+            // regexes of tensor names to be quantized
+            const std::vector<std::string> k_names = {
+                "model/wte",
+                "model/h.*/attn/c_attn/w",
+                "model/h.*/attn/c_proj/w",
+                "model/h.*/mlp/c_fc/w",
+                "model/h.*/mlp/c_proj/w",
+            };
+
+            bool quantize = false;
+            for (const auto & s : k_names) {
+                if (std::regex_match(name, std::regex(s))) {
+                    quantize = true;
+                    break;
+                }
+            }
+
+            if (quantize) {
+                if (ftype != 0 && ftype != 1) {
+                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
+                    return false;
+                }
+
+                if (ftype == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                ftype = itype;
+            } else {
+                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+
+                data_u8.resize(nelements*bpe);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
+            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+            fout.write(&name[0], length);
+
+            if (quantize) {
+                printf("quantizing .. ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                }
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_org += nelements * sizeof(float);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+        {
+            int64_t sum_all = 0;
+            for (int i = 0; i < hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("%s: hist: ", __func__);
+            for (int i = 0; i < hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / (float)sum_all);
+            }
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+        fprintf(stderr, "  type = 2 - q4_0\n");
+        fprintf(stderr, "  type = 3 - q4_1\n");
+        return 1;
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const int itype = atoi(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_quantize(fname_inp, fname_out, itype)) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    return 0;
+}
index 4199a3faed54a42dcd8eae98c31b33dbcf46a7aa..390746d50e540af9fe5d546296a4519c9f8f3782 100644 (file)
@@ -4,3 +4,10 @@
 set(TEST_TARGET gpt-j)
 add_executable(${TEST_TARGET} main.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
+
+#
+# gpt-j-quantize
+
+set(TEST_TARGET gpt-j-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
index 310e60e0a06a269bbd6f43255ee75a4031a01452..e254f2cc3d50493e9376d1c01a3b28645a6ae5f1 100644 (file)
@@ -47,8 +47,10 @@ def bytes_to_unicode():
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
 
-if len(sys.argv) < 2:
+if len(sys.argv) < 3:
     print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
+    print("  ftype == 0 -> float32")
+    print("  ftype == 1 -> float16")
     sys.exit(1)
 
 # output in the same directory as the model
@@ -64,11 +66,21 @@ with open(dir_model + "/added_tokens.json", "r") as f:
 with open(dir_model + "/config.json", "r") as f:
     hparams = json.load(f)
 
-# use 16-bit or 32-bit floats
-use_f16 = True
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
 if len(sys.argv) > 2:
-    use_f16 = False
-    fname_out = sys.argv[1] + "/ggml-model-f32.bin"
+    ftype = int(sys.argv[2])
+    if ftype < 0 or ftype > 1:
+        print("Invalid ftype: " + str(ftype))
+        sys.exit(1)
+    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+
 
 model = GPTJForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)
 #print (model)
@@ -85,7 +97,7 @@ fout.write(struct.pack("i", hparams["n_embd"]))
 fout.write(struct.pack("i", hparams["n_head"]))
 fout.write(struct.pack("i", hparams["n_layer"]))
 fout.write(struct.pack("i", hparams["rotary_dim"]))
-fout.write(struct.pack("i", use_f16))
+fout.write(struct.pack("i", ftype))
 
 byte_encoder = bytes_to_unicode()
 byte_decoder = {v:k for k, v in byte_encoder.items()}
@@ -114,34 +126,40 @@ for name in list_vars.keys():
     n_dims = len(data.shape);
 
     # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype = 0;
-    if use_f16:
+    ftype_cur = 0;
+    if ftype != 0:
         if name[-7:] == ".weight" and n_dims == 2:
             print("  Converting to float16")
             data = data.astype(np.float16)
-            ftype = 1
+            ftype_cur = 1
         else:
             print("  Converting to float32")
             data = data.astype(np.float32)
-            ftype = 0
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
 
     # for efficiency - transpose these matrices:
-    #  "transformer.h.*.mlp.fc_in.weight
-    #  "transformer.h.*.attn.out_proj.weight
+    # (note - with latest ggml this is no longer more efficient, so disabling it)
+    #  "transformer.h.*.mlp.fc_in.weight"
+    #  "transformer.h.*.attn.out_proj.weight"
     #  "transformer.h.*.attn.q_proj.weight"
     #  "transformer.h.*.attn.k_proj.weight"
     #  "transformer.h.*.attn.v_proj.weight"
-    if name.endswith(".mlp.fc_in.weight")     or \
-       name.endswith(".attn.out_proj.weight") or \
-       name.endswith(".attn.q_proj.weight")   or \
-       name.endswith(".attn.k_proj.weight")   or \
-       name.endswith(".attn.v_proj.weight"):
-        print("  Transposing")
-        data = data.transpose()
+    #if name.endswith(".mlp.fc_in.weight")     or \
+    #   name.endswith(".attn.out_proj.weight") or \
+    #   name.endswith(".attn.q_proj.weight")   or \
+    #   name.endswith(".attn.k_proj.weight")   or \
+    #   name.endswith(".attn.v_proj.weight"):
+    #    print("  Transposing")
+    #    data = data.transpose()
 
     # header
     str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
     for i in range(n_dims):
         fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
     fout.write(str);
index 5c6ce93a081f98bc0bf42cfc3abd120e1d0785da..c059e1b72caf32ad8b8386bb1bb464af62de4160 100644 (file)
@@ -40,7 +40,7 @@ struct gptj_layer {
     struct ggml_tensor * c_mlp_fc_w;
     struct ggml_tensor * c_mlp_fc_b;
 
-    struct ggml_tensor * c_mlp_proj_w_trans;
+    struct ggml_tensor * c_mlp_proj_w;
     struct ggml_tensor * c_mlp_proj_b;
 };
 
@@ -132,9 +132,23 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
         }
     }
 
-    // for the big tensors, we have the option to store the data in 16-bit floats
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
     // in order to save memory and also to speed up the computation
-    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    ggml_type wtype = GGML_TYPE_COUNT;
+    switch (model.hparams.f16) {
+        case 0: wtype = GGML_TYPE_F32;  break;
+        case 1: wtype = GGML_TYPE_F16;  break;
+        case 2: wtype = GGML_TYPE_Q4_0; break;
+        case 3: wtype = GGML_TYPE_Q4_1; break;
+        default:
+                {
+                    fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
+                            __func__, fname.c_str(), model.hparams.f16);
+                    return false;
+                }
+    }
+
+    const ggml_type wtype2 = GGML_TYPE_F32;
 
     auto & ctx = model.ctx;
 
@@ -148,31 +162,31 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
         const int n_ctx   = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;
 
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
 
-        ctx_size += n_embd*n_vocab*ggml_type_size(wtype); // wte
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte
 
-        ctx_size += n_embd*n_vocab*ggml_type_size(wtype);         // lmh_g
-        ctx_size +=        n_vocab*ggml_type_size(GGML_TYPE_F32); // lmh_b
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype);         // lmh_g
+        ctx_size +=        n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b
 
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_q_proj_w
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_k_proj_w
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_v_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_q_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_k_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_v_proj_w
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_fc_w
-        ctx_size += n_layer*(       4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_proj_w_trans
-        ctx_size += n_layer*(         n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
+        ctx_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
 
         ctx_size += (5 + 10*n_layer)*256; // object overhead
 
@@ -224,20 +238,20 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
         for (int i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
 
-            layer.ln_1_g                = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
-            layer.ln_1_b                = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_g          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
 
-            layer.c_attn_q_proj_w       = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
-            layer.c_attn_k_proj_w       = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
-            layer.c_attn_v_proj_w       = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+            layer.c_attn_q_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+            layer.c_attn_k_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+            layer.c_attn_v_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
 
-            layer.c_attn_proj_w         = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+            layer.c_attn_proj_w   = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
 
-            layer.c_mlp_fc_w            = ggml_new_tensor_2d(ctx, wtype,         4*n_embd,   n_embd);
-            layer.c_mlp_fc_b            = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+            layer.c_mlp_fc_w      = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);
+            layer.c_mlp_fc_b      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
 
-            layer.c_mlp_proj_w_trans    = ggml_new_tensor_2d(ctx, wtype,         4*n_embd,   n_embd);
-            layer.c_mlp_proj_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.c_mlp_proj_w    = ggml_new_tensor_2d(ctx, wtype,         4*n_embd,   n_embd);
+            layer.c_mlp_proj_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
 
             // map by name
             model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"]          = layer.ln_1_g;
@@ -252,7 +266,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
             model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"]     = layer.c_mlp_fc_w;
             model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"]       = layer.c_mlp_fc_b;
 
-            model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"]    = layer.c_mlp_proj_w_trans;
+            model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"]    = layer.c_mlp_proj_w;
             model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.bias"]      = layer.c_mlp_proj_b;
         }
     }
@@ -323,9 +337,26 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
                 return false;
             }
 
-            const size_t bpe = tensor->type == GGML_TYPE_I8 ? 1 : (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+            if (0) {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
 
-            if (nelements*bpe != ggml_nbytes(tensor)) {
+            size_t bpe = 0;
+
+            switch (ftype) {
+                case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
+                case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
+                case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
+                case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
+                default:
+                        {
+                            fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                            return false;
+                        }
+            };
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                 fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                 return false;
@@ -430,9 +461,9 @@ bool gptj_eval(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_q_proj_w), cur);
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_k_proj_w), cur);
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_v_proj_w), cur);
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur);
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur);
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur);
 
             // store key and value to memory
             if (N >= 1) {
@@ -481,11 +512,13 @@ bool gptj_eval(
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             struct ggml_tensor * V_trans =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
-                            n_embd/n_head, n_head, n_past + N),
-                        1, 2, 0, 3);
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
 
             // KQV = transpose(V) * KQ_soft_max
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
@@ -500,7 +533,7 @@ bool gptj_eval(
 
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,
-                    ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
+                    model.layers[il].c_attn_proj_w,
                     cur);
         }
 
@@ -511,7 +544,7 @@ bool gptj_eval(
         {
             // note here we pass inpSA instead of cur
             cur = ggml_mul_mat(ctx0,
-                    ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
+                    model.layers[il].c_mlp_fc_w,
                     inpSA);
 
             cur = ggml_add(ctx0,
@@ -524,7 +557,7 @@ bool gptj_eval(
             // projection
             // cur = proj_w*cur + proj_b
             cur = ggml_mul_mat(ctx0,
-                    model.layers[il].c_mlp_proj_w_trans,
+                    model.layers[il].c_mlp_proj_w,
                     cur);
 
             cur = ggml_add(ctx0,
diff --git a/examples/gpt-j/quantize.cpp b/examples/gpt-j/quantize.cpp
new file mode 100644 (file)
index 0000000..29b92cf
--- /dev/null
@@ -0,0 +1,324 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <regex>
+
+// default hparams (GPT-J 6B)
+struct gptj_hparams {
+    int32_t n_vocab = 50400;
+    int32_t n_ctx   = 2048;
+    int32_t n_embd  = 4096;
+    int32_t n_head  = 16;
+    int32_t n_layer = 28;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+// quantize a model
+bool gptj_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    switch (itype) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
+    };
+
+    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
+        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
+        return false;
+    }
+
+    gpt_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+    }
+
+    gptj_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        finp.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fout.write((char *) &itype,           sizeof(hparams.f16));
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        finp.read ((char *) &n_vocab, sizeof(n_vocab));
+        fout.write((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size_org = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t>     data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float>       data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
+            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            finp.read (&name[0], length);
+
+            {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
+            }
+
+            // regexes of tensor names to be quantized
+            const std::vector<std::string> k_names = {
+                ".*weight",
+            };
+
+            bool quantize = false;
+            for (const auto & s : k_names) {
+                if (std::regex_match(name, std::regex(s))) {
+                    quantize = true;
+                    break;
+                }
+            }
+
+            // quantize only 2D tensors
+            quantize &= (n_dims == 2);
+
+            if (quantize) {
+                if (ftype != 0 && ftype != 1) {
+                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
+                    return false;
+                }
+
+                if (ftype == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                ftype = itype;
+            } else {
+                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+
+                data_u8.resize(nelements*bpe);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
+            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+            fout.write(&name[0], length);
+
+            if (quantize) {
+                printf("quantizing .. ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                }
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_org += nelements * sizeof(float);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+        {
+            int64_t sum_all = 0;
+            for (int i = 0; i < hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("%s: hist: ", __func__);
+            for (int i = 0; i < hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / (float)sum_all);
+            }
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+        fprintf(stderr, "  type = 2 - q4_0\n");
+        fprintf(stderr, "  type = 3 - q4_1\n");
+        return 1;
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const int itype = atoi(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gptj_model_quantize(fname_inp, fname_out, itype)) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    return 0;
+}
index d091d3db16a0da5102f1de2292d69d6cb957024f..b61173ffd4d8dffe024e2bccb7c9aab590492e7b 100644 (file)
@@ -20,7 +20,7 @@ struct gpt_params {
     // sampling parameters
     int32_t top_k = 40;
     float   top_p = 0.9f;
-    float   temp  = 1.0f;
+    float   temp  = 0.9f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
@@ -81,4 +81,3 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double top_p,
         double temp,
         std::mt19937 & rng);
-
index c8fa83a831d3e71ac190042d70a5633d6355d4dd..c7f5ff54ea0f9e5b2d9d2c51c13965761b180c39 100644 (file)
@@ -13,3 +13,10 @@ set(TEST_TARGET whisper)
 add_executable(${TEST_TARGET} main.cpp common.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp)
 target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
+
+#
+# whisper-quantize
+
+set(TEST_TARGET whisper-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
index 9e9b2dcebefb92b0e38a600d3c0fee78e6fe8717..749f99c885ebdeacc6db32d812f35079754baf75 100644 (file)
@@ -303,8 +303,9 @@ for name in list_vars.keys():
             data = data.astype(np.float32)
             ftype = 0
     else:
-        data = data.astype(np.float32)
-        ftype = 0
+        if n_dims < 3 and data.dtype != np.float32:
+            data = data.astype(np.float32)
+            ftype = 0
 
     #if name.startswith("encoder"):
     #    if name.endswith("mlp.0.weight") or \
index b8366b79f4581cc474930a99080d562f4374eb75..dd30ba4c473766169e10594283c1e33a66483891 100644 (file)
@@ -73,6 +73,7 @@ struct whisper_params {
     bool output_srt     = false;
     bool output_wts     = false;
     bool output_csv     = false;
+    bool output_jsn     = false;
     bool print_special  = false;
     bool print_colors   = false;
     bool print_progress = false;
@@ -80,6 +81,7 @@ struct whisper_params {
 
     std::string language = "en";
     std::string prompt;
+    std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
     std::string model    = "models/ggml-base.en.bin";
 
     std::vector<std::string> fname_inp = {};
@@ -127,7 +129,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ovtt" || arg == "--output-vtt")     { params.output_vtt     = true; }
         else if (arg == "-osrt" || arg == "--output-srt")     { params.output_srt     = true; }
         else if (arg == "-owts" || arg == "--output-words")   { params.output_wts     = true; }
+        else if (arg == "-fp"   || arg == "--font-path")      { params.font_path      = argv[++i]; }
         else if (arg == "-ocsv" || arg == "--output-csv")     { params.output_csv     = true; }
+        else if (arg == "-oj"   || arg == "--output-json")    { params.output_jsn     = true; }
         else if (arg == "-of"   || arg == "--output-file")    { params.fname_out.emplace_back(argv[++i]); }
         else if (arg == "-ps"   || arg == "--print-special")  { params.print_special  = true; }
         else if (arg == "-pc"   || arg == "--print-colors")   { params.print_colors   = true; }
@@ -174,7 +178,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ovtt,     --output-vtt        [%-7s] output result in a vtt file\n",                    params.output_vtt ? "true" : "false");
     fprintf(stderr, "  -osrt,     --output-srt        [%-7s] output result in a srt file\n",                    params.output_srt ? "true" : "false");
     fprintf(stderr, "  -owts,     --output-words      [%-7s] output script for generating karaoke video\n",     params.output_wts ? "true" : "false");
+    fprintf(stderr, "  -fp,       --font-path         [%-7s] path to a monospace font for karaoke video\n",     params.font_path.c_str());
     fprintf(stderr, "  -ocsv,     --output-csv        [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
+    fprintf(stderr, "  -oj,       --output-json       [%-7s] output result in a JSON file\n",                   params.output_jsn ? "true" : "false");
     fprintf(stderr, "  -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n",      "");
     fprintf(stderr, "  -ps,       --print-special     [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
     fprintf(stderr, "  -pc,       --print-colors      [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
@@ -193,7 +199,7 @@ struct whisper_print_user_data {
     const std::vector<std::vector<float>> * pcmf32s;
 };
 
-void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
     const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
@@ -352,28 +358,157 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
 
     const int n_segments = whisper_full_n_segments(ctx);
+    fout << "start,end,text\n";
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
         //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
-        fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text    << "\"\n";
+        fout << 10 * t0 << "," << 10 * t1 << ",\"" << text    << "\"\n";
     }
 
     return true;
 }
 
+bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
+    std::ofstream fout(fname);
+    int indent = 0;
+
+    auto doindent = [&]() {
+        for (int i = 0; i < indent; i++) fout << "\t";
+    };
+
+    auto start_arr = [&](const char *name) {
+        doindent();
+        fout << "\"" << name << "\": [\n";
+        indent++;
+    };
+
+    auto end_arr = [&](bool end = false) {
+        indent--;
+        doindent();
+        fout << (end ? "]\n" : "},\n");
+    };
+
+    auto start_obj = [&](const char *name = nullptr) {
+        doindent();
+        if (name) {
+            fout << "\"" << name << "\": {\n";
+        } else {
+            fout << "{\n";
+        }
+        indent++;
+    };
+
+    auto end_obj = [&](bool end = false) {
+        indent--;
+        doindent();
+        fout << (end ? "}\n" : "},\n");
+    };
+
+    auto start_value = [&](const char *name) {
+        doindent();
+        fout << "\"" << name << "\": ";
+    };
+
+    auto value_s = [&](const char *name, const char *val, bool end = false) {
+        start_value(name);
+        fout << "\"" << val << (end ? "\"\n" : "\",\n");
+    };
+
+    auto end_value = [&](bool end = false) {
+        fout << (end ? "\n" : ",\n");
+    };
+
+    auto value_i = [&](const char *name, const int64_t val, bool end = false) {
+        start_value(name);
+        fout << val;
+        end_value(end);
+    };
+
+    auto value_b = [&](const char *name, const bool val, bool end = false) {
+        start_value(name);
+        fout << (val ? "true" : "false");
+        end_value(end);
+    };
+
+    if (!fout.is_open()) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+        return false;
+    }
+
+    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+    start_obj();
+        value_s("systeminfo", whisper_print_system_info());
+        start_obj("model");
+            value_s("type", whisper_model_type_readable(ctx));
+            value_b("multilingual", whisper_is_multilingual(ctx));
+            value_i("vocab", whisper_model_n_vocab(ctx));
+            start_obj("audio");
+                value_i("ctx", whisper_model_n_audio_ctx(ctx));
+                value_i("state", whisper_model_n_audio_state(ctx));
+                value_i("head", whisper_model_n_audio_head(ctx));
+                value_i("layer", whisper_model_n_audio_layer(ctx), true);
+            end_obj();
+            start_obj("text");
+                value_i("ctx", whisper_model_n_text_ctx(ctx));
+                value_i("state", whisper_model_n_text_state(ctx));
+                value_i("head", whisper_model_n_text_head(ctx));
+                value_i("leyer", whisper_model_n_text_layer(ctx), true);
+            end_obj();
+            value_i("mels", whisper_model_n_mels(ctx));
+            value_i("f16", whisper_model_f16(ctx), true);
+        end_obj();
+        start_obj("params");
+            value_s("model", params.model.c_str());
+            value_s("language", params.language.c_str());
+            value_b("translate", params.translate, true);
+        end_obj();
+        start_obj("result");
+            value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
+        end_obj();
+        start_arr("transcription");
+
+            const int n_segments = whisper_full_n_segments(ctx);
+            for (int i = 0; i < n_segments; ++i) {
+                const char * text = whisper_full_get_segment_text(ctx, i);
+                const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+                start_obj();
+                    start_obj("timestanps");
+                        value_s("from", to_timestamp(t0, true).c_str());
+                        value_s("to", to_timestamp(t1, true).c_str(), true);
+                    end_obj();
+                    start_obj("offsets");
+                        value_i("from", t0 * 10);
+                        value_i("to", t1 * 10, true);
+                    end_obj();
+                    value_s("text", text, true);
+                end_obj(i == (n_segments - 1));
+            }
+
+        end_arr(true);
+    end_obj(true);
+    return true;
+}
+
 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
     std::ofstream fout(fname);
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
 
-    // TODO: become parameter
-    static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+    static const char * font = params.font_path.c_str();
+
+    std::ifstream fin(font);
+    if (!fin.is_open()) {
+        fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font);
+        return false;
+    }
 
     fout << "#!/bin/bash" << "\n";
     fout << "\n";
@@ -607,7 +742,7 @@ int main(int argc, char ** argv) {
             {
                 static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
                     bool is_aborted = *(bool*)user_data;
                     return !is_aborted;
                 };
@@ -653,6 +788,12 @@ int main(int argc, char ** argv) {
                 const auto fname_csv = fname_out + ".csv";
                 output_csv(ctx, fname_csv.c_str());
             }
+
+            // output to JSON file
+            if (params.output_jsn) {
+                const auto fname_jsn = fname_out + ".json";
+                output_json(ctx, fname_jsn.c_str(), params);
+            }
         }
     }
 
diff --git a/examples/whisper/quantize.cpp b/examples/whisper/quantize.cpp
new file mode 100644 (file)
index 0000000..8042d69
--- /dev/null
@@ -0,0 +1,373 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <regex>
+
+// default hparams (Whisper tiny)
+struct whisper_hparams {
+    int32_t n_vocab       = 51864;
+    int32_t n_audio_ctx   = 1500;
+    int32_t n_audio_state = 384;
+    int32_t n_audio_head  = 6;
+    int32_t n_audio_layer = 4;
+    int32_t n_text_ctx    = 448;
+    int32_t n_text_state  = 384;
+    int32_t n_text_head   = 6;
+    int32_t n_text_layer  = 4;
+    int32_t n_mels        = 80;
+    int32_t f16           = 1;
+};
+
+struct whisper_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector<float> data;
+};
+
+// quantize a model
+bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    switch (itype) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
+    };
+
+    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
+        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
+        return false;
+    }
+
+    gpt_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+    }
+
+    whisper_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab,       sizeof(hparams.n_vocab));
+        finp.read((char *) &hparams.n_audio_ctx,   sizeof(hparams.n_audio_ctx));
+        finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
+        finp.read((char *) &hparams.n_audio_head,  sizeof(hparams.n_audio_head));
+        finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
+        finp.read((char *) &hparams.n_text_ctx,    sizeof(hparams.n_text_ctx));
+        finp.read((char *) &hparams.n_text_state,  sizeof(hparams.n_text_state));
+        finp.read((char *) &hparams.n_text_head,   sizeof(hparams.n_text_head));
+        finp.read((char *) &hparams.n_text_layer,  sizeof(hparams.n_text_layer));
+        finp.read((char *) &hparams.n_mels,        sizeof(hparams.n_mels));
+        finp.read((char *) &hparams.f16,           sizeof(hparams.f16));
+
+        fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        fprintf(stderr, "%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        fprintf(stderr, "%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        fprintf(stderr, "%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        fprintf(stderr, "%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        fprintf(stderr, "%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        fprintf(stderr, "%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab,       sizeof(hparams.n_vocab));
+        fout.write((char *) &hparams.n_audio_ctx,   sizeof(hparams.n_audio_ctx));
+        fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
+        fout.write((char *) &hparams.n_audio_head,  sizeof(hparams.n_audio_head));
+        fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
+        fout.write((char *) &hparams.n_text_ctx,    sizeof(hparams.n_text_ctx));
+        fout.write((char *) &hparams.n_text_state,  sizeof(hparams.n_text_state));
+        fout.write((char *) &hparams.n_text_head,   sizeof(hparams.n_text_head));
+        fout.write((char *) &hparams.n_text_layer,  sizeof(hparams.n_text_layer));
+        fout.write((char *) &hparams.n_mels,        sizeof(hparams.n_mels));
+        fout.write((char *) &itype,                 sizeof(hparams.f16));
+    }
+
+    // load mel filters
+    {
+        whisper_filters filters;
+
+        finp.read ((char *) &filters.n_mel, sizeof(filters.n_mel));
+        fout.write((char *) &filters.n_mel, sizeof(filters.n_mel));
+        finp.read ((char *) &filters.n_fft, sizeof(filters.n_fft));
+        fout.write((char *) &filters.n_fft, sizeof(filters.n_fft));
+
+        filters.data.resize(filters.n_mel * filters.n_fft);
+        finp.read ((char *) filters.data.data(), filters.data.size() * sizeof(float));
+        fout.write((char *) filters.data.data(), filters.data.size() * sizeof(float));
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        finp.read ((char *) &n_vocab, sizeof(n_vocab));
+        fout.write((char *) &n_vocab, sizeof(n_vocab));
+
+        //if (n_vocab != hparams.n_vocab) {
+        //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //            __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+        //    return false;
+        //}
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size_org = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t>     data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float>       data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
+            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[3] = { 1, 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            finp.read (&name[0], length);
+
+            {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%48s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ftype_str[ftype]);
+            }
+
+            // regexes of tensor names to not be quantized
+            const std::vector<std::string> k_names = {
+                //"encoder.*",
+                "encoder.conv1.bias",
+                "encoder.conv2.bias",
+                "encoder.positional_embedding",
+                "decoder.positional_embedding",
+            };
+
+            bool quantize = true;
+            for (const auto & s : k_names) {
+                if (std::regex_match(name, std::regex(s))) {
+                    quantize = false;
+                    break;
+                }
+            }
+
+            // quantize only 2D and 3D tensors
+            quantize &= (n_dims == 2);
+
+            if (quantize) {
+                if (ftype != 0 && ftype != 1) {
+                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
+                    return false;
+                }
+
+                if (ftype == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                ftype = itype;
+            } else {
+                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+
+                data_u8.resize(nelements*bpe);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
+            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+            fout.write(&name[0], length);
+
+            if (quantize) {
+                printf("quantizing .. ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.3f MB -> %8.3f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                }
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_org += nelements * sizeof(float);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+        {
+            int64_t sum_all = 0;
+            for (int i = 0; i < hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("%s: hist: ", __func__);
+            for (int i = 0; i < hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / (float)sum_all);
+            }
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+        fprintf(stderr, "  type = 2 - q4_0\n");
+        fprintf(stderr, "  type = 3 - q4_1\n");
+        return 1;
+    }
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const int itype = atoi(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!whisper_model_quantize(fname_inp, fname_out, itype)) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    return 0;
+}
index 3a21581c682c641ac20ccc2de7e7c25470f0df64..f44e5034bb804c711bd91b24cee9a46468fa1e65 100644 (file)
@@ -218,14 +218,14 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "su",  { 98,  "sundanese",      } },
 };
 
-static const size_t MB = 1024*1024;
+static const size_t MB = 1ull*1024*1024;
 
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-    { MODEL_TINY,     12ull*MB },
-    { MODEL_BASE,     15ull*MB },
-    { MODEL_SMALL,    23ull*MB },
-    { MODEL_MEDIUM,   31ull*MB },
-    { MODEL_LARGE,    38ull*MB },
+    { MODEL_TINY,     14ull*MB },
+    { MODEL_BASE,     18ull*MB },
+    { MODEL_SMALL,    28ull*MB },
+    { MODEL_MEDIUM,   36ull*MB },
+    { MODEL_LARGE,    44ull*MB },
 };
 
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
@@ -547,13 +547,11 @@ struct whisper_decoder {
     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
 };
 
-struct whisper_context {
-    int64_t t_load_us   = 0;
-    int64_t t_mel_us    = 0;
+struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
-    int64_t t_start_us  = 0;
+    int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
@@ -561,16 +559,10 @@ struct whisper_context {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
-    ggml_type wtype; // weight type (FP32 or FP16)
-
-    whisper_mel mel;
-
-    whisper_model model;
-    whisper_vocab vocab;
-
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+    whisper_mel mel;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
@@ -635,6 +627,19 @@ struct whisper_context {
     }
 };
 
+struct whisper_context {
+    int64_t t_load_us = 0;
+    int64_t t_start_us = 0;
+
+    ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
+
+    whisper_model model;
+    whisper_vocab vocab;
+    whisper_state * state = nullptr;
+
+    std::string path_model; // populated by whisper_init_from_file()
+};
+
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
@@ -821,32 +826,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.model.buf = new std::vector<uint8_t>();
         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
-            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
-            return false;
-        }
-
-        {
-            const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
-            fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-        }
-
-        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
-            fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
-            return false;
-        }
-
-        {
-            const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
-            fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-        }
-
-        wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-
-        wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
-        wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
-        wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
-        wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
+        // we skip initialization of the state until it is needed
+        // because it might be that state will always be provided externally.
     }
 
     // load mel filters
@@ -929,17 +910,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 vocab.id_to_token[i] = word;
             }
         }
-
-        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-
-        wctx.logits_id.reserve(n_vocab);
-
-        // TAGS: WHISPER_DECODER_INIT
-        wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-
-        wctx.decoders[0].probs.reserve   (vocab.n_vocab);
-        wctx.decoders[0].logits.reserve  (vocab.n_vocab);
-        wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
     }
 
     size_t ctx_size = 0;
@@ -1339,33 +1309,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    wctx.rng = std::mt19937(0);
-
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
 }
 
-// evaluate the encoder
+// evaluate the encoder with the given state
 //
 // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
 // part of the transformer model and returns the encoded features
 //
-//   - model:      the model
+//   - wctx:      the model
+//   - wstate:     the state of the encoder
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 //
-static bool whisper_encode(
+static bool whisper_encode_internal(
         whisper_context & wctx,
+          whisper_state & wstate,
               const int   mel_offset,
-              const int   n_threads) {
+              const int   n_threads){
+
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model   = wctx.model;
-    const auto & mel_inp = wctx.mel;
+    const auto & mel_inp = wstate.mel;
     const auto & hparams = model.hparams;
 
-    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
@@ -1374,12 +1345,12 @@ static bool whisper_encode(
     assert(mel_inp.n_mel == n_mels);
 
     struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_size   = wstate.buf_compute.size();
+    params.mem_buffer = wstate.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    wctx.use_buf(ctx0, 0);
+    wstate.use_buf(ctx0, 0);
 
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
@@ -1401,30 +1372,30 @@ static bool whisper_encode(
 
     // convolution + gelu
     {
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
         cur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    model.e_conv_1_b,
-                    cur),
-                cur);
+            ggml_repeat(ctx0,
+                model.e_conv_1_b,
+                cur),
+            cur);
 
         cur = ggml_gelu(ctx0, cur);
 
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
         cur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    model.e_conv_2_b,
-                    cur),
-                cur);
+            ggml_repeat(ctx0,
+                model.e_conv_2_b,
+                cur),
+            cur);
 
         cur = ggml_gelu(ctx0, cur);
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1459,54 +1430,54 @@ static bool whisper_encode(
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpL);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
+                    cur),
+                ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
         }
 
         // self-attention
         {
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                    layer.attn_q_w,
-                    cur);
+                layer.attn_q_w,
+                cur);
 
             Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_q_b,
-                        Qcur),
-                    Qcur);
+                ggml_repeat(ctx0,
+                    layer.attn_q_b,
+                    Qcur),
+                Qcur);
 
             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                    layer.attn_k_w,
-                    cur);
+                layer.attn_k_w,
+                cur);
 
             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                    layer.attn_v_w,
-                    cur);
+                layer.attn_v_w,
+                cur);
 
             Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
+                ggml_repeat(ctx0,
+                    layer.attn_v_b,
+                    Vcur),
+                Vcur);
 
             // ------
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
 #ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
@@ -1583,29 +1554,29 @@ static bool whisper_encode(
 #endif
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                KQV_merged,
+                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
         }
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
-                    layer.attn_ln_1_w,
-                    cur);
+                layer.attn_ln_1_w,
+                cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                    cur);
+                ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
+                cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         cur = ggml_add(ctx0, cur, inpL);
@@ -1616,61 +1587,61 @@ static bool whisper_encode(
         {
             // norm
             {
-                wctx.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
                 cur = ggml_norm(ctx0, inpFF);
 
-                wctx.use_buf(ctx0, 1);
+                wstate.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.mlp_ln_w, cur),
+                        cur),
+                    ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }
 
 #ifdef WHISPER_USE_FLASH_FF
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_flash_ff(ctx0,
-                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
-                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+                ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
+                layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // fully connected
             cur = ggml_mul_mat(ctx0,
-                    layer.mlp_0_w,
-                    cur);
+                layer.mlp_0_w,
+                cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                    cur);
+                ggml_repeat(ctx0, layer.mlp_0_b, cur),
+                cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // projection
             cur = ggml_mul_mat(ctx0,
-                    layer.mlp_1_w,
-                    cur);
+                layer.mlp_1_w,
+                cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                    cur);
+                ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                cur);
 #endif
         }
 
-        wctx.use_buf(ctx0, 3);
+        wstate.use_buf(ctx0, 3);
 
         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -1679,21 +1650,21 @@ static bool whisper_encode(
 
     // norm
     {
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_norm(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         // cur = ln_f_g*cur + ln_f_b
         cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.e_ln_w, cur),
-                    cur),
-                ggml_repeat(ctx0, model.e_ln_b, cur));
+            ggml_mul(ctx0,
+                ggml_repeat(ctx0, model.e_ln_w, cur),
+                cur),
+            ggml_repeat(ctx0, model.e_ln_b, cur));
     }
 
-    wctx.use_buf(ctx0, -1);
+    wstate.use_buf(ctx0, -1);
 
     // run the computation
     {
@@ -1701,7 +1672,7 @@ static bool whisper_encode(
         gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
-        ggml_graph_compute       (ctx0, &gf);
+        ggml_graph_compute(ctx0, &gf);
 
         //ggml_graph_print(&gf);
     }
@@ -1731,34 +1702,34 @@ static bool whisper_encode(
         cur->src1 = nullptr;
 
         for (int il = 0; il < model.hparams.n_text_layer; ++il) {
-            auto & layer = model.layers_decoder[il];
+            auto& layer = model.layers_decoder[il];
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
-            struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
-                    layer.cross_attn_k_w,
-                    cur);
+            struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_k_w,
+                cur);
 
-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
-            struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
-                    layer.cross_attn_v_w,
-                    cur);
+            struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_v_w,
+                cur);
 
             Vcross = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.cross_attn_v_b,
-                        Vcross),
-                    Vcross);
+                ggml_repeat(ctx0,
+                    layer.cross_attn_v_b,
+                    Vcross),
+                Vcross);
 
-            wctx.use_buf(ctx0, -1);
+            wstate.use_buf(ctx0, -1);
 
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
+            //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+            struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1772,15 +1743,15 @@ static bool whisper_encode(
 
     //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
     //        ggml_used_mem(ctx0)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
-    //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
+    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
 
     ggml_free(ctx0);
 
-    wctx.t_encode_us += ggml_time_us() - t_start_us;
-    wctx.n_encode++;
+    wstate.t_encode_us += ggml_time_us() - t_start_us;
+    wstate.n_encode++;
 
     return true;
 }
@@ -1795,8 +1766,9 @@ static bool whisper_encode(
 //   - n_tokens:   number of tokens in the prompt
 //   - n_past:     number of past tokens to prefix the prompt with
 //
-static bool whisper_decode(
+static bool whisper_decode_internal(
         whisper_context & wctx,
+          whisper_state & wstate,
         whisper_decoder & decoder,
     const whisper_token * tokens,
               const int   n_tokens,
@@ -1811,7 +1783,7 @@ static bool whisper_decode(
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    auto & logits_out = wctx.logits;
+    auto & logits_out = wstate.logits;
 
     const int n_vocab = hparams.n_vocab;
 
@@ -1821,13 +1793,13 @@ static bool whisper_decode(
     const int n_layer = hparams.n_text_layer;
 
     const int N = n_tokens;
-    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_size   = wstate.buf_compute.size();
+    params.mem_buffer = wstate.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1842,7 +1814,7 @@ static bool whisper_decode(
         ((int32_t *) position->data)[i] = n_past + i;
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     // token encoding + position encoding
     struct ggml_tensor * cur =
@@ -1857,7 +1829,7 @@ static bool whisper_decode(
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpL);
 
@@ -1871,8 +1843,6 @@ static bool whisper_decode(
 
         // self-attention
         {
-            wctx.use_buf(ctx0, 1);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);
@@ -1913,7 +1883,7 @@ static bool whisper_decode(
 
             // ------
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -1929,13 +1899,11 @@ static bool whisper_decode(
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            wctx.use_buf(ctx0, 0);
-
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
             //            KQ,
@@ -1944,20 +1912,16 @@ static bool whisper_decode(
 
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
-            wctx.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            wctx.use_buf(ctx0, 0);
-
             struct ggml_tensor * V_trans =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        1, 2, 0, 3);
-
-            wctx.use_buf(ctx0, 1);
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
+                                n_state/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
@@ -1970,32 +1934,30 @@ static bool whisper_decode(
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                     cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
 
-            wctx.use_buf(ctx0, 1);
-
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
@@ -2006,8 +1968,6 @@ static bool whisper_decode(
 
         // cross-attention
         {
-            wctx.use_buf(ctx0, 0);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
                     cur);
@@ -2023,20 +1983,21 @@ static bool whisper_decode(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
+                        ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
                 ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
+                        ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
                         n_state/n_head, n_head, M);
 
-            struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
 
             // ------
 
-            wctx.use_buf(ctx0, 1);
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                         ggml_cpy(ctx0,
@@ -2046,8 +2007,6 @@ static bool whisper_decode(
 
             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 0);
-
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
@@ -2060,16 +2019,10 @@ static bool whisper_decode(
             // no masking for cross-attention
             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            wctx.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            wctx.use_buf(ctx0, 0);
-
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            wctx.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
             // cur = KQV_merged.contiguous().view(n_state, N)
@@ -2080,20 +2033,20 @@ static bool whisper_decode(
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                     cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         cur = ggml_add(ctx0, cur, inpCA);
@@ -2104,11 +2057,11 @@ static bool whisper_decode(
         {
             // norm
             {
-                wctx.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
                 cur = ggml_norm(ctx0, inpFF);
 
-                wctx.use_buf(ctx0, 1);
+                wstate.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
@@ -2118,39 +2071,39 @@ static bool whisper_decode(
                         ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
                     cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // projection
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
                     cur);
         }
 
-        wctx.use_buf(ctx0, 3);
+        wstate.use_buf(ctx0, 3);
 
         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -2159,11 +2112,11 @@ static bool whisper_decode(
 
     // norm
     {
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_norm(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
@@ -2172,7 +2125,7 @@ static bool whisper_decode(
                 ggml_repeat(ctx0, model.d_ln_b, cur));
     }
 
-    wctx.use_buf(ctx0, 0);
+    wstate.use_buf(ctx0, 0);
 
     // compute logits only for the last token
     // comment this line to compute logits for all N tokens
@@ -2181,7 +2134,7 @@ static bool whisper_decode(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
-    wctx.use_buf(ctx0, -1);
+    wstate.use_buf(ctx0, -1);
 
     // run the computation
     {
@@ -2200,16 +2153,16 @@ static bool whisper_decode(
     if (N > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
-        //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
+        //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
+        //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
     ggml_free(ctx0);
 
-    wctx.t_decode_us += ggml_time_us() - t_start_us;
-    wctx.n_decode++;
+    wstate.t_decode_us += ggml_time_us() - t_start_us;
+    wstate.n_decode++;
 
     return true;
 }
@@ -2313,7 +2266,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
 static bool log_mel_spectrogram(
-        whisper_context & wctx,
+          whisper_state & wstate,
             const float * samples,
               const int   n_samples,
               const int   /*sample_rate*/,
@@ -2433,7 +2386,7 @@ static bool log_mel_spectrogram(
         mel.data[i] = (mel.data[i] + 4.0)/4.0;
     }
 
-    wctx.t_mel_us += ggml_time_us() - t_start_us;
+    wstate.t_mel_us += ggml_time_us() - t_start_us;
 
     return true;
 }
@@ -2507,7 +2460,54 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 // interface implementation
 //
 
-struct whisper_context * whisper_init_from_file(const char * path_model) {
+struct whisper_state * whisper_init_state(whisper_context * ctx) {
+    whisper_state * state = new whisper_state;
+
+    const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
+
+    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
+        fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
+        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+    }
+
+    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
+        fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
+        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+    }
+
+    state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
+
+    state->logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    // TAGS: WHISPER_DECODER_INIT
+    state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
+
+    state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+    state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
+
+    state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
+    state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
+    state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
+    state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
+
+    state->rng = std::mt19937(0);
+
+    return state;
+}
+
+struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
     whisper_model_loader loader = {};
 
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@@ -2535,10 +2535,16 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
         fin->close();
     };
 
-    return whisper_init(&loader);
+    auto ctx = whisper_init_no_state(&loader);
+
+    if (ctx) {
+        ctx->path_model = path_model;
+    }
+
+    return ctx;
 }
 
-struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
     struct buf_context {
         uint8_t* buffer;
         size_t size;
@@ -2571,10 +2577,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
 
     loader.close = [](void * /*ctx*/) { };
 
-    return whisper_init(&loader);
+    return whisper_init_no_state(&loader);
 }
 
-struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
@@ -2591,6 +2597,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     return ctx;
 }
 
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+    whisper_context * ctx = whisper_init_no_state(loader);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+void whisper_free_state(struct whisper_state * state)
+{
+    if (state) {
+        kv_cache_free(state->kv_cross);
+
+        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+            kv_cache_free(state->decoders[i].kv_self);
+        }
+
+        delete state;
+    }
+}
+
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
         if (ctx->model.ctx) {
@@ -2599,20 +2663,29 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.buf) {
             delete ctx->model.buf;
         }
-        if (ctx->kv_cross.ctx) {
-            ggml_free(ctx->kv_cross.ctx);
-        }
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            if (ctx->decoders[i].kv_self.ctx) {
-                ggml_free(ctx->decoders[i].kv_self.ctx);
-            }
-        }
+
+        whisper_free_state(ctx->state);
+
         delete ctx;
     }
 }
 
+int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
+        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        return -1;
+    }
+
+    return 0;
+}
+
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+    return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
+}
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
@@ -2622,11 +2695,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
 
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+    return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
+}
+
+int whisper_set_mel_with_state(
+        struct whisper_context * /*ctx*/,
+          struct whisper_state * state,
+                   const float * data,
+                           int   n_len,
+                           int   n_mel) {
+    if (n_mel != WHISPER_N_MEL) {
+        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
         return -1;
     }
 
+    state->mel.n_len = n_len;
+    state->mel.n_mel = n_mel;
+
+    state->mel.data.resize(n_len*n_mel);
+    memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
+
     return 0;
 }
 
@@ -2635,22 +2723,20 @@ int whisper_set_mel(
         const float * data,
         int n_len,
         int n_mel) {
-    if (n_mel != WHISPER_N_MEL) {
-        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+    return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
+}
+
+int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
+    if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
 
-    ctx->mel.n_len = n_len;
-    ctx->mel.n_mel = n_mel;
-
-    ctx->mel.data.resize(n_len*n_mel);
-    memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
-
     return 0;
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode(*ctx, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
@@ -2658,11 +2744,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     return 0;
 }
 
+int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
+    const int selected_decoder_id = 0;
+
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to context
+    // TODO: add selected_decoder_id to state
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (ctx->state == nullptr) {
+        fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
+        return false;
+    }
+
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2720,11 +2823,12 @@ const char * whisper_lang_str(int id) {
     return nullptr;
 }
 
-int whisper_lang_auto_detect(
+int whisper_lang_auto_detect_with_state(
         struct whisper_context * ctx,
-        int offset_ms,
-        int n_threads,
-        float * lang_probs) {
+          struct whisper_state * state,
+                           int   offset_ms,
+                           int   n_threads,
+                         float * lang_probs) {
     const int seek = offset_ms/10;
 
     if (seek < 0) {
@@ -2732,30 +2836,30 @@ int whisper_lang_auto_detect(
         return -1;
     }
 
-    if (seek >= ctx->mel.n_len) {
-        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+    if (seek >= state->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
         return -2;
     }
 
     // run the encoder
-    if (whisper_encode(ctx, seek, n_threads) != 0) {
+    if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
         fprintf(stderr, "%s: failed to encode\n", __func__);
         return -6;
     }
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
-    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+    if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
         fprintf(stderr, "%s: failed to decode\n", __func__);
         return -7;
     }
 
-    auto & logits_id = ctx->logits_id;
+    auto & logits_id = state->logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
+        logits_id.emplace_back(state->logits[token_lang], kv.second.first);
     }
 
     // sort descending
@@ -2794,8 +2898,85 @@ int whisper_lang_auto_detect(
     return logits_id[0].second;
 }
 
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+                           int   offset_ms,
+                           int   n_threads,
+                         float * lang_probs) {
+    return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
+}
+
+int whisper_model_n_vocab(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_vocab;
+}
+
+int whisper_model_n_audio_ctx(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_ctx;
+}
+
+int whisper_model_n_audio_state(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_state;
+}
+
+int whisper_model_n_audio_head(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_head;
+}
+
+int whisper_model_n_audio_layer(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_layer;
+}
+
+int whisper_model_n_text_ctx(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_text_ctx;
+}
+
+int whisper_model_n_text_state(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_text_state;
+}
+
+int whisper_model_n_text_head(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_text_head;
+}
+
+int whisper_model_n_text_layer(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_text_layer;
+}
+
+int whisper_model_n_mels(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_mels;
+}
+
+int whisper_model_f16(struct whisper_context * ctx) {
+    return ctx->model.hparams.f16;
+}
+
+int whisper_model_type(struct whisper_context * ctx) {
+    return ctx->model.type;
+}
+
+const char *whisper_model_type_readable(struct whisper_context * ctx) {
+    switch (ctx->model.type) {
+    case e_model::MODEL_TINY:
+        return "tiny";
+    case e_model::MODEL_BASE:
+        return "base";
+    case e_model::MODEL_SMALL:
+        return "small";
+    case e_model::MODEL_MEDIUM:
+        return "medium";
+    case e_model::MODEL_LARGE:
+        return "large";
+    default:
+        return "unknown";
+    }
+}
+
+int whisper_n_len_from_state(struct whisper_state * state) {
+    return state->mel.n_len;
+}
+
 int whisper_n_len(struct whisper_context * ctx) {
-    return ctx->mel.n_len;
+    return ctx->state->mel.n_len;
 }
 
 int whisper_n_vocab(struct whisper_context * ctx) {
@@ -2815,7 +2996,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
 }
 
 float * whisper_get_logits(struct whisper_context * ctx) {
-    return ctx->logits.data();
+    return ctx->state->logits.data();
+}
+
+
+float * whisper_get_logits_from_state(struct whisper_state * state) {
+    return state->logits.data();
 }
 
 const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
@@ -2861,24 +3047,29 @@ whisper_token whisper_token_transcribe(void) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    const int32_t n_sample = std::max(1, ctx->n_sample);
-    const int32_t n_encode = std::max(1, ctx->n_encode);
-    const int32_t n_decode = std::max(1, ctx->n_decode);
-
     fprintf(stderr, "\n");
-    fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
-    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
-    fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
-    fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
-    fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
-    fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
+    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    if (ctx->state != nullptr) {
+
+        const int32_t n_sample = std::max(1, ctx->state->n_sample);
+        const int32_t n_encode = std::max(1, ctx->state->n_encode);
+        const int32_t n_decode = std::max(1, ctx->state->n_decode);
+
+        fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+    }
     fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
-    ctx->t_sample_us = 0;
-    ctx->t_encode_us = 0;
-    ctx->t_decode_us = 0;
+    if (ctx->state != nullptr) {
+        ctx->state->t_sample_us = 0;
+        ctx->state->t_encode_us = 0;
+        ctx->state->t_decode_us = 0;
+    }
 }
 
 const char * whisper_print_system_info(void) {
@@ -2913,7 +3104,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.duration_ms      =*/ 0,
 
         /*.translate        =*/ false,
-        /*.no_context       =*/ false,
+        /*.no_context       =*/ true,
         /*.single_segment   =*/ false,
         /*.print_special    =*/ false,
         /*.print_progress   =*/ true,
@@ -2942,7 +3133,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts   =*/  1.0f,
         /*.length_penalty   =*/ -1.0f,
 
-        /*.temperature_inc  =*/  0.2f,
+        /*.temperature_inc  =*/  0.0f, // TODO: temporary disabled until improve performance
         /*.entropy_thold    =*/  2.4f,
         /*.logprob_thold    =*/ -1.0f,
         /*.no_speech_thold  =*/  0.6f,
@@ -2991,6 +3182,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
         struct whisper_context & ctx,
+          struct whisper_state & state,
                            int   i_segment,
                          float   thold_pt,
                          float   thold_ptsum);
@@ -3023,8 +3215,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
-    auto segment = ctx.result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
+    auto segment = state.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -3046,24 +3238,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
                 trim(text);
             }
 
-            ctx.result_all.back().text = std::move(text);
-            ctx.result_all.back().t1 = token.t0;
-            ctx.result_all.back().tokens.resize(i);
+            state.result_all.back().text = std::move(text);
+            state.result_all.back().t1 = token.t0;
+            state.result_all.back().tokens.resize(i);
 
-            ctx.result_all.push_back({});
-            ctx.result_all.back().t0 = token.t0;
-            ctx.result_all.back().t1 = segment.t1;
+            state.result_all.push_back({});
+            state.result_all.back().t0 = token.t0;
+            state.result_all.back().t1 = segment.t1;
 
             // add tokens [i, end] to the new segment
-            ctx.result_all.back().tokens.insert(
-                    ctx.result_all.back().tokens.end(),
+            state.result_all.back().tokens.insert(
+                state.result_all.back().tokens.end(),
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
             acc = 0;
             text = "";
 
-            segment = ctx.result_all.back();
+            segment = state.result_all.back();
             i = -1;
 
             res++;
@@ -3076,7 +3268,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     if (split_on_word) {
         trim(text);
     }
-    ctx.result_all.back().text = std::move(text);
+    state.result_all.back().text = std::move(text);
 
     return res;
 }
@@ -3093,6 +3285,7 @@ static const std::vector<std::string> non_speech_tokens = {
 // - computes logprobs and probs
 static void whisper_process_logits(
               struct whisper_context & ctx,
+               struct whisper_state  & state,
     const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
                                float   temperature) {
@@ -3111,7 +3304,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -3149,7 +3342,7 @@ static void whisper_process_logits(
         logits[vocab.token_transcribe] = -INFINITY;
 
         if (params.logits_filter_callback) {
-            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+            params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
 
         // suppress non-speech tokens
@@ -3310,6 +3503,7 @@ static void whisper_process_logits(
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
+              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -3354,7 +3548,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id   = dist(ctx.rng);
+        result.id   = dist(state.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }
@@ -3364,13 +3558,14 @@ static whisper_token_data whisper_sample_token(
         result.pt  = result.p;
     }
 
-    ctx.n_sample++;
+    state.n_sample++;
 
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
+              whisper_state & state,
       const whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
@@ -3381,7 +3576,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = ctx.logits_id;
+    auto & logits_id = state.logits_id;
 
     logits_id.clear();
     for (int i = 0; i < n_logits; ++i) {
@@ -3434,7 +3629,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    ctx.n_sample++;
+    state.n_sample++;
 
     return result;
 }
@@ -3488,24 +3683,25 @@ static void whisper_sequence_score(
     }
 }
 
-int whisper_full(
+int whisper_full_with_state(
         struct whisper_context * ctx,
-        struct whisper_full_params params,
-        const float * samples,
-        int n_samples) {
+          struct whisper_state * state,
+    struct whisper_full_params   params,
+                   const float * samples,
+                           int   n_samples) {
     // clear old results
-    auto & result_all = ctx->result_all;
+    auto & result_all = state->result_all;
 
     result_all.clear();
 
     // compute log mel spectrogram
     if (params.speed_up) {
-        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
             return -1;
         }
     } else {
-        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
@@ -3515,26 +3711,26 @@ int whisper_full(
     if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
         std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
 
-        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
             fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
             return -3;
         }
-        ctx->lang_id = lang_id;
+        state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
         fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
     }
 
     if (params.token_timestamps) {
-        ctx->t_beg    = 0;
-        ctx->t_last   = 0;
-        ctx->tid_last = 0;
-        ctx->energy = get_signal_energy(samples, n_samples, 32);
+        state->t_beg    = 0;
+        state->t_last   = 0;
+        state->tid_last = 0;
+        state->energy = get_signal_energy(samples, n_samples, 32);
     }
 
     const int seek_start = params.offset_ms/10;
-    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
+    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
 
     // if length of spectrogram is less than 1s (100 samples), then return
     // basically don't process anything that is less than 1s
@@ -3572,10 +3768,10 @@ int whisper_full(
 
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = ctx->decoders[0].kv_self;
+            decoder.kv_self = state->decoders[0].kv_self;
             if (!kv_cache_reinit(decoder.kv_self)) {
                 fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
@@ -3583,7 +3779,7 @@ int whisper_full(
 
             WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
 
-            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
             decoder.probs.resize   (ctx->vocab.n_vocab);
             decoder.logits.resize  (ctx->vocab.n_vocab);
@@ -3592,7 +3788,7 @@ int whisper_full(
     }
 
     // the accumulated text context so far
-    auto & prompt_past = ctx->prompt_past;
+    auto & prompt_past = state->prompt_past;
     if (params.no_context) {
         prompt_past.clear();
     }
@@ -3611,13 +3807,13 @@ int whisper_full(
         fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
-    ctx->exp_n_audio_ctx = params.audio_ctx;
+    state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
-        ctx->lang_id = lang_id;
+        state->lang_id = lang_id;
         prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
@@ -3669,14 +3865,14 @@ int whisper_full(
         }
 
         if (params.encoder_begin_callback) {
-            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+            if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
                 fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }
 
         // encode audio features starting at offset seek
-        if (!whisper_encode(*ctx, seek, params.n_threads)) {
+        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
             return -6;
         }
@@ -3717,7 +3913,7 @@ int whisper_full(
 
             // TAGS: WHISPER_DECODER_INIT
             for (int j = 0; j < n_decoders_cur; ++j) {
-                auto & decoder = ctx->decoders[j];
+                auto & decoder = state->decoders[j];
 
                 decoder.kv_self.n = 0;
 
@@ -3759,7 +3955,7 @@ int whisper_full(
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                     fprintf(stderr, "%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -3767,24 +3963,24 @@ int whisper_full(
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
-                    whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
+                    whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
-                    ctx->decoders[0].kv_self.n += prompt.size();
+                    state->decoders[0].kv_self.n += prompt.size();
 
                     for (int j = 1; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
-                        memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+                        memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+                        memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
 
                         decoder.kv_self.n += prompt.size();
 
-                        memcpy(decoder.probs.data(),    ctx->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
-                        memcpy(decoder.logits.data(),   ctx->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
-                        memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
+                        memcpy(decoder.probs.data(), state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
+                        memcpy(decoder.logits.data(), state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+                        memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                     }
 
-                    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
 
@@ -3795,7 +3991,7 @@ int whisper_full(
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
                     kv_bufs.resize(n_decoders_cur);
                     for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
                         if (decoder.completed || decoder.failed) {
                             continue;
@@ -3813,7 +4009,7 @@ int whisper_full(
 
                 // generate new sequence candidates for each decoder
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.completed || decoder.failed) {
                         continue;
@@ -3823,16 +4019,16 @@ int whisper_full(
                         case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                             {
                                 if (t_cur < 1e-6f) {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
                                 } else {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
                                 }
 
                                 decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
                             } break;
                         case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                             {
-                                const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+                                const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
 
                                 for (const auto & token : tokens_new) {
                                     beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
@@ -3857,7 +4053,7 @@ int whisper_full(
                     uint32_t cur_c = 0;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
                         if (decoder.completed || decoder.failed) {
                             continue;
@@ -3886,7 +4082,7 @@ int whisper_full(
                 // - check if the sequence is failed
                 // - update sliding window based on timestamp tokens
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.completed || decoder.failed) {
                         continue;
@@ -3968,7 +4164,7 @@ int whisper_full(
                     bool completed_all = true;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
                         if (decoder.completed || decoder.failed) {
                             continue;
@@ -3982,11 +4178,11 @@ int whisper_full(
                     }
                 }
 
-                ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.failed || decoder.completed) {
                         continue;
@@ -3997,7 +4193,7 @@ int whisper_full(
 
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                    if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
                         fprintf(stderr, "%s: failed to decode\n", __func__);
                         return -8;
                     }
@@ -4005,11 +4201,11 @@ int whisper_full(
                     {
                         const int64_t t_start_sample_us = ggml_time_us();
 
-                        whisper_process_logits(*ctx, params, decoder, t_cur);
+                        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
 
                         ++decoder.kv_self.n;
 
-                        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
                     }
                 }
             }
@@ -4019,7 +4215,7 @@ int whisper_full(
                 double best_score = -INFINITY;
 
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.failed) {
                         continue;
@@ -4036,7 +4232,7 @@ int whisper_full(
                                 __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
                         decoder.failed = true;
-                        ctx->n_fail_h++;
+                        state->n_fail_h++;
 
                         continue;
                     }
@@ -4054,11 +4250,11 @@ int whisper_full(
             {
                 bool success = true;
 
-                const auto & decoder = ctx->decoders[best_decoder_id];
+                const auto & decoder = state->decoders[best_decoder_id];
 
                 if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
                     success = false;
-                    ctx->n_fail_p++;
+                    state->n_fail_p++;
                 }
 
                 if (success) {
@@ -4075,7 +4271,7 @@ int whisper_full(
 
         // output results through a user-provided callback
         {
-            const auto & best_decoder = ctx->decoders[best_decoder_id];
+            const auto & best_decoder = state->decoders[best_decoder_id];
 
             const auto seek_delta = best_decoder.seek_delta;
             const auto result_len = best_decoder.sequence.result_len;
@@ -4138,14 +4334,14 @@ int whisper_full(
 
                             if (params.token_timestamps) {
                                 whisper_exp_compute_token_level_timestamps(
-                                        *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                                        *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                                 if (params.max_len > 0) {
-                                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+                                    n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                                 }
                             }
                             if (params.new_segment_callback) {
-                                params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                                params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                             }
                         }
                         text = "";
@@ -4182,14 +4378,14 @@ int whisper_full(
 
                     if (params.token_timestamps) {
                         whisper_exp_compute_token_level_timestamps(
-                                *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                                *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                         if (params.max_len > 0) {
-                            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+                            n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                         }
                     }
                     if (params.new_segment_callback) {
-                        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                        params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                     }
                 }
             }
@@ -4204,6 +4400,15 @@ int whisper_full(
     return 0;
 }
 
+
+int whisper_full(
+        struct whisper_context * ctx,
+    struct whisper_full_params   params,
+                   const float * samples,
+                           int   n_samples) {
+    return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
+}
+
 int whisper_full_parallel(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -4213,40 +4418,10 @@ int whisper_full_parallel(
     if (n_processors == 1) {
         return whisper_full(ctx, params, samples, n_samples);
     }
-
     int ret = 0;
 
-    // prepare separate contexts for each thread
-    std::vector<struct whisper_context> ctxs(n_processors - 1);
-
-    for (int i = 0; i < n_processors - 1; ++i) {
-        auto & ctx_p = ctxs[i];
-
-        ctx_p = *ctx;
-
-        ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
-
-        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
-
-        if (!kv_cache_reinit(ctx_p.kv_cross)) {
-            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
-            return false;
-        }
-
-        // TAGS: WHISPER_DECODER_INIT
-        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
-            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
-                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
-                return false;
-            }
-
-            ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
-
-            ctx_p.decoders[j].probs.reserve   (ctx_p.vocab.n_vocab);
-            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
-            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
-        }
-    }
+    // prepare separate states for each thread
+    std::vector<whisper_state*> states;
 
     const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
     const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
@@ -4256,6 +4431,9 @@ int whisper_full_parallel(
 
     std::vector<std::thread> workers(n_processors - 1);
     for (int i = 0; i < n_processors - 1; ++i) {
+        // create a new state for each thread
+        states.push_back(whisper_init_state(ctx));
+
         const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
         const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
 
@@ -4268,13 +4446,17 @@ int whisper_full_parallel(
         params_cur.new_segment_callback = nullptr;
         params_cur.new_segment_callback_user_data = nullptr;
 
-        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+        workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
     }
 
     {
         auto params_cur = params;
 
-        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+        // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
+        params_cur.print_realtime = false;
+
+        // Run the first transformation using default state but only for the first chunk.
+        ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
     }
 
     for (int i = 0; i < n_processors - 1; ++i) {
@@ -4283,45 +4465,43 @@ int whisper_full_parallel(
 
     const int64_t offset_t = (int64_t) params.offset_ms/10.0;
 
-    // combine results into ctx->result_all
+    // combine results into result_state->result_all from all other states
     for (int i = 0; i < n_processors - 1; ++i) {
-        auto & results_i = ctxs[i].result_all;
+        auto& results_i = states[i]->result_all;
 
-        for (auto & result : results_i) {
+        for (auto& result : results_i) {
             // correct the segment timestamp taking into account the offset
-            result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
-            result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+            result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+
 
             // make sure that segments are not overlapping
-            if (!ctx->result_all.empty()) {
-                result.t0 = std::max(result.t0, ctx->result_all.back().t1);
+            if (!ctx->state->result_all.empty()) {
+                result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
             }
 
-            ctx->result_all.push_back(std::move(result));
+            ctx->state->result_all.push_back(std::move(result));
 
             // call the new_segment_callback for each segment
             if (params.new_segment_callback) {
-                params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
+                params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
             }
         }
 
-        ctx->t_mel_us    += ctxs[i].t_mel_us;
-        ctx->t_sample_us += ctxs[i].t_sample_us;
-        ctx->t_encode_us += ctxs[i].t_encode_us;
-        ctx->t_decode_us += ctxs[i].t_decode_us;
+        ctx->state->t_mel_us += states[i]->t_mel_us;
 
-        kv_cache_free(ctx->kv_cross);
+        ctx->state->t_sample_us += states[i]->t_sample_us;
+        ctx->state->t_encode_us += states[i]->t_encode_us;
+        ctx->state->t_decode_us += states[i]->t_decode_us;
 
-        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
-            kv_cache_free(ctx->decoders[j].kv_self);
-        }
+        whisper_free_state(states[i]);
     }
 
     // average the timings
-    ctx->t_mel_us    /= n_processors;
-    ctx->t_sample_us /= n_processors;
-    ctx->t_encode_us /= n_processors;
-    ctx->t_decode_us /= n_processors;
+    ctx->state->t_mel_us    /= n_processors;
+    ctx->state->t_sample_us /= n_processors;
+    ctx->state->t_encode_us /= n_processors;
+    ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
     fprintf(stderr, "\n");
@@ -4334,44 +4514,84 @@ int whisper_full_parallel(
     return ret;
 }
 
+int whisper_full_n_segments_from_state(struct whisper_state * state) {
+    return state->result_all.size();
+}
+
 int whisper_full_n_segments(struct whisper_context * ctx) {
-    return ctx->result_all.size();
+    return ctx->state->result_all.size();
+}
+
+int whisper_full_lang_id_from_state(struct whisper_state * state) {
+    return state->lang_id;
 }
 
 int whisper_full_lang_id(struct whisper_context * ctx) {
-    return ctx->lang_id;
+    return ctx->state->lang_id;
+}
+
+int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].t0;
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].t0;
+    return ctx->state->result_all[i_segment].t0;
+}
+
+int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].t1;
 }
 
 int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].t1;
+    return ctx->state->result_all[i_segment].t1;
+}
+
+const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].text.c_str();
 }
 
 const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].text.c_str();
+    return ctx->state->result_all[i_segment].text.c_str();
+}
+
+int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].tokens.size();
 }
 
 int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].tokens.size();
+    return ctx->state->result_all[i_segment].tokens.size();
+}
+
+const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
 }
 
-const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
+    return state->result_all[i_segment].tokens[i_token].id;
 }
 
 whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token].id;
+    return ctx->state->result_all[i_segment].tokens[i_token].id;
+}
+
+struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
+    return state->result_all[i_segment].tokens[i_token];
 }
 
 struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token];
+    return ctx->state->result_all[i_segment].tokens[i_token];
+}
+
+float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
+    return state->result_all[i_segment].tokens[i_token].p;
 }
 
 float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token].p;
+    return ctx->state->result_all[i_segment].tokens[i_token].p;
 }
 
 // =================================================================================================
@@ -4382,6 +4602,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
 //
 
 WHISPER_API int whisper_bench_memcpy(int n_threads) {
+    fputs(whisper_bench_memcpy_str(n_threads), stderr);
+    return 0;
+}
+
+WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
+    static std::string s;
+    s = "";
+    char strbuf[256];
+
     ggml_time_init();
 
     size_t n    = 50;
@@ -4411,7 +4640,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
         src[0] = rand();
     }
 
-    fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+    snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+    s += strbuf;
 
     // needed to prevent the compile from optimizing the memcpy away
     {
@@ -4419,16 +4649,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
 
         for (size_t i = 0; i < size; i++) sum += dst[i];
 
-        fprintf(stderr, "sum:    %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+        snprintf(strbuf, sizeof(strbuf), "sum:    %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+        s += strbuf;
     }
 
     free(src);
     free(dst);
 
-    return 0;
+    return s.c_str();
 }
 
 WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
+    fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr);
+    return 0;
+}
+
+WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
+    static std::string s;
+    s = "";
+    char strbuf[256];
+
     ggml_time_init();
 
     const int n_max = 128;
@@ -4504,11 +4744,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
             s = ((2.0*N*N*N*n)/tsum)*1e-9;
         }
 
-        fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
+        snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
             N, N, s_fp16, n_fp16, s_fp32, n_fp32);
+        s += strbuf;
     }
 
-    return 0;
+    return s.c_str();
 }
 
 // =================================================================================================
@@ -4583,13 +4824,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 
 static void whisper_exp_compute_token_level_timestamps(
         struct whisper_context & ctx,
+          struct whisper_state & state,
                            int   i_segment,
                          float   thold_pt,
                          float   thold_ptsum) {
-    auto & segment = ctx.result_all[i_segment];
+    auto & segment = state.result_all[i_segment];
     auto & tokens  = segment.tokens;
 
-    const int n_samples = ctx.energy.size();
+    const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -4612,9 +4854,9 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }
 
-    auto & t_beg    = ctx.t_beg;
-    auto & t_last   = ctx.t_last;
-    auto & tid_last = ctx.tid_last;
+    auto & t_beg    = state.t_beg;
+    auto & t_last   = state.t_last;
+    auto & tid_last = state.tid_last;
 
     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
@@ -4737,15 +4979,15 @@ static void whisper_exp_compute_token_level_timestamps(
             float sum = 0.0f;
 
             for (int k = ss0; k < ss1; k++) {
-                sum += ctx.energy[k];
+                sum += state.energy[k];
             }
 
             const float thold = 0.5*sum/ns;
 
             {
                 int k = s0;
-                if (ctx.energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold && j > 0) {
+                    while (k > 0 && state.energy[k] > thold) {
                         k--;
                     }
                     tokens[j].t0 = sample_to_timestamp(k);
@@ -4755,7 +4997,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s0 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k < s1) {
+                    while (state.energy[k] < thold && k < s1) {
                         k++;
                     }
                     s0 = k;
@@ -4765,8 +5007,8 @@ static void whisper_exp_compute_token_level_timestamps(
 
             {
                 int k = s1;
-                if (ctx.energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold) {
+                    while (k < n_samples - 1 && state.energy[k] > thold) {
                         k++;
                     }
                     tokens[j].t1 = sample_to_timestamp(k);
@@ -4776,7 +5018,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s1 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k > s0) {
+                    while (state.energy[k] < thold && k > s0) {
                         k--;
                     }
                     s1 = k;
index 3eb8d08420948ca35533d19ffbd9047a55a78493..fc107108ad810da73e2fc4d8a0339e73bccf749e 100644 (file)
@@ -66,6 +66,7 @@ extern "C" {
     //
 
     struct whisper_context;
+    struct whisper_state;
 
     typedef int whisper_token;
 
@@ -101,11 +102,20 @@ extern "C" {
     WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
     WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
 
-    // Frees all memory allocated by the model.
-    WHISPER_API void whisper_free(struct whisper_context * ctx);
+    // These are the same as the above, but the internal state of the context is not allocated automatically
+    // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
+    WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
+    WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
+
+    WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
+
+    // Frees all allocated memory
+    WHISPER_API void whisper_free      (struct whisper_context * ctx);
+    WHISPER_API void whisper_free_state(struct whisper_state * state);
 
     // Convert RAW PCM audio to log mel spectrogram.
-    // The resulting spectrogram is stored inside the provided whisper context.
+    // The resulting spectrogram is stored inside the default state of the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel(
             struct whisper_context * ctx,
@@ -113,17 +123,30 @@ extern "C" {
                                int   n_samples,
                                int   n_threads);
 
-    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. 
-    // The resulting spectrogram is stored inside the provided whisper context.
+    WHISPER_API int whisper_pcm_to_mel_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                       const float * samples,
+                               int   n_samples,
+                               int   n_threads);
+
+    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
+    // The resulting spectrogram is stored inside the default state of the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
-        struct whisper_context* ctx,
-        const float* samples,
-        int   n_samples,
-        int   n_threads);
-
-
-    // This can be used to set a custom log mel spectrogram inside the provided whisper context.
+        struct whisper_context * ctx,
+                   const float * samples,
+                           int   n_samples,
+                           int   n_threads);
+
+    WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
+        struct whisper_context * ctx,
+          struct whisper_state * state,
+                   const float * samples,
+                           int   n_samples,
+                           int   n_threads);
+
+    // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
     // Returns 0 on success
@@ -133,7 +156,14 @@ extern "C" {
                                int   n_len,
                                int   n_mel);
 
-    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+    WHISPER_API int whisper_set_mel_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                       const float * data,
+                               int   n_len,
+                               int   n_mel);
+
+    // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
     // offset can be used to specify the offset of the first frame in the spectrogram.
     // Returns 0 on success
@@ -142,6 +172,12 @@ extern "C" {
                                int   offset,
                                int   n_threads);
 
+    WHISPER_API int whisper_encode_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                               int   offset,
+                               int   n_threads);
+
     // Run the Whisper decoder to obtain the logits and probabilities for the next token.
     // Make sure to call whisper_encode() first.
     // tokens + n_tokens is the provided context for the decoder.
@@ -155,6 +191,14 @@ extern "C" {
                                int   n_past,
                                int   n_threads);
 
+    WHISPER_API int whisper_decode_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+               const whisper_token * tokens,
+                               int   n_tokens,
+                               int   n_past,
+                               int   n_threads);
+
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
@@ -190,20 +234,44 @@ extern "C" {
                                int   n_threads,
                              float * lang_probs);
 
-    WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
-    WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
-    WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
+    WHISPER_API int whisper_lang_auto_detect_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                               int   offset_ms,
+                               int   n_threads,
+                             float * lang_probs);
+
+    WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
+    WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
+    WHISPER_API int whisper_n_vocab         (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_text_ctx      (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx     (struct whisper_context * ctx);
+    WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
+
+    WHISPER_API int whisper_model_n_vocab      (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_audio_ctx  (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_text_ctx   (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_text_head  (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_n_mels       (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_f16          (struct whisper_context * ctx);
+    WHISPER_API int whisper_model_type         (struct whisper_context * ctx);
 
     // Token logits obtained from the last call to whisper_decode()
     // The logits for the last token are stored in the last row
     // Rows: n_tokens
     // Cols: n_vocab
-    WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
+    WHISPER_API float * whisper_get_logits           (struct whisper_context * ctx);
+    WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
 
     // Token Id -> String. Uses the vocabulary in the provided context
     WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
+    WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
+
 
     // Special tokens
     WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
@@ -218,7 +286,7 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_translate (void);
     WHISPER_API whisper_token whisper_token_transcribe(void);
 
-    // Performance information
+    // Performance information from the default state.
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
 
@@ -236,18 +304,19 @@ extern "C" {
     // Text segment callback
     // Called on every newly generated text segment
     // Use the whisper_full_...() functions to obtain the text segments
-    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
 
     // Encoder begin callback
     // If not NULL, called before the encoder starts
     // If it returns false, the computation is aborted
-    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
     typedef void (*whisper_logits_filter_callback)(
             struct whisper_context * ctx,
+              struct whisper_state * state,
           const whisper_token_data * tokens,
                                int   n_tokens,
                              float * logits,
@@ -334,6 +403,7 @@ extern "C" {
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+    // Not thread safe for same context
     // Uses the specified decoding strategy to obtain the text.
     WHISPER_API int whisper_full(
                 struct whisper_context * ctx,
@@ -341,7 +411,16 @@ extern "C" {
                            const float * samples,
                                    int   n_samples);
 
-    // Split the input audio in chunks and process each chunk separately using whisper_full()
+    WHISPER_API int whisper_full_with_state(
+                struct whisper_context * ctx,
+                  struct whisper_state * state,
+            struct whisper_full_params   params,
+                           const float * samples,
+                                   int   n_samples);
+
+    // Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
+    // Result is stored in the default state of the context
+    // Not thread safe if executed in parallel on the same context.
     // It seems this approach can offer some speedup in some cases.
     // However, the transcription accuracy can be worse at the beginning and end of each chunk.
     WHISPER_API int whisper_full_parallel(
@@ -351,40 +430,56 @@ extern "C" {
                                    int   n_samples,
                                    int   n_processors);
 
-    // Number of generated text segments.
+    // Number of generated text segments
     // A segment can be a few words, a sentence, or even a paragraph.
-    WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
+    WHISPER_API int whisper_full_n_segments           (struct whisper_context * ctx);
+    WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
 
-    // Language id associated with the current context
+    // Language id associated with the context's default state
     WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
 
-    // Get the start and end time of the specified segment.
-    WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
-    WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
+    // Language id associated with the provided state
+    WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
+
+    // Get the start and end time of the specified segment
+    WHISPER_API int64_t whisper_full_get_segment_t0           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
+
+    WHISPER_API int64_t whisper_full_get_segment_t1           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
+
+    // Get the text of the specified segment
+    WHISPER_API const char * whisper_full_get_segment_text           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
 
-    // Get the text of the specified segment.
-    WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
+    // Get number of tokens in the specified segment
+    WHISPER_API int whisper_full_n_tokens           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
 
-    // Get number of tokens in the specified segment.
-    WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+    // Get the token text of the specified token in the specified segment
+    WHISPER_API const char * whisper_full_get_token_text           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
 
-    // Get the token text of the specified token in the specified segment.
-    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
 
-    // Get token data for the specified token in the specified segment.
+    // Get token data for the specified token in the specified segment
     // This contains probabilities, timestamps, etc.
-    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
 
-    // Get the probability of the specified token in the specified segment.
-    WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
+    // Get the probability of the specified token in the specified segment
+    WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
 
     ////////////////////////////////////////////////////////////////////////////
 
     // Temporary helpers needed for exposing ggml interface
 
     WHISPER_API int whisper_bench_memcpy(int n_threads);
+    WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
     WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+    WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
 
 #ifdef __cplusplus
 }
index 18f317bec04dee0df2aede91ea477ad6ff6f621f..335230f9f0bb26e669b4d7156da26ca588d516be 100644 (file)
@@ -198,6 +198,8 @@ struct ggml_object;
 struct ggml_context;
 
 enum ggml_type {
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -226,7 +228,9 @@ enum ggml_op {
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    GGML_OP_SILU,
     GGML_OP_NORM, // normalize
+    GGML_OP_RMS_NORM,
 
     GGML_OP_MUL_MAT,
 
@@ -326,7 +330,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
 int    ggml_nelements(const struct ggml_tensor * tensor);
 size_t ggml_nbytes   (const struct ggml_tensor * tensor);
 
-size_t ggml_type_size   (enum ggml_type type);
+int    ggml_blck_size (enum ggml_type type);
+size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
 struct ggml_context * ggml_init(struct ggml_init_params params);
@@ -336,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
 
 size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
 
+bool ggml_mlock_supported(void);
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
@@ -466,12 +476,20 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // normalize along rows
 // TODO: eps is hardcoded to 1e-5 for now
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows
@@ -726,6 +744,13 @@ enum ggml_opt_result ggml_opt(
         struct ggml_opt_params params,
         struct ggml_tensor * f);
 
+//
+// quantization
+//
+
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+
 //
 // system info
 //
index df2235f3ca470c970de9a15de46f7db377f33ee7..02675ee67072d7668c3c018733cda1ca18927688 100644 (file)
@@ -1,18 +1,23 @@
+// Defines CLOCK_MONOTONIC and asprintf on Linux
+#define _GNU_SOURCE
+
 #include "ggml.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
-#elif !defined(__FreeBSD__)
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
 #include <alloca.h>
 #endif
 
 #include <assert.h>
+#include <errno.h>
 #include <time.h>
 #include <math.h>
 #include <stdlib.h>
 #include <string.h>
 #include <stdint.h>
 #include <stdio.h>
+#include <float.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #else
 // ref: https://github.com/ggerganov/whisper.cpp/issues/168
 #include <windows.h>
-#include <errno.h>
 #endif
 
-// Need this to compile with Visual Studio 2017
-#define restrict __restrict
-
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
 
@@ -78,13 +79,38 @@ static int sched_yield (void) {
 typedef void* thread_ret_t;
 #endif
 
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#endif
+
 #ifdef __HAIKU__
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #endif
 
+#define GGML_MLOCK_SUPPORT 0
+
+#ifdef __has_include
+    #if __has_include(<sys/mman.h>)
+        #undef GGML_MLOCK_SUPPORT
+        #define GGML_MLOCK_SUPPORT 1
+        #include <sys/mman.h>
+    #endif
+#endif
+
+
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
@@ -137,10 +163,10 @@ typedef double ggml_float;
 //
 #include <arm_neon.h>
 
-#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
 #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
 
-#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP16_TO_FP32(x) ((float) (x))
 #define GGML_FP32_TO_FP16(x) (x)
 
 #else
@@ -159,8 +185,46 @@ typedef double ggml_float;
 
 #ifdef __F16C__
 
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
 #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
 #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
 
 #else
 
@@ -248,6 +312,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed silu table for f16 (128 KB)
+static ggml_fp16_t table_silu_f16[1 << 16];
+
 // precomputed exp table for f16 (128 KB)
 static ggml_fp16_t table_exp_f16[1 << 16];
 
@@ -256,6 +323,7 @@ static float table_f32_f16[1 << 16];
 
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+// This is also true for POWER9.
 #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
 
 inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
@@ -272,7 +340,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 // note: do not use these inside ggml.c
 // these are meant to be used via the ggml.h API
 float ggml_fp16_to_fp32(ggml_fp16_t x) {
-    return GGML_FP16_TO_FP32(x);
+    return (float) GGML_FP16_TO_FP32(x);
 }
 
 ggml_fp16_t ggml_fp32_to_fp16(float x) {
@@ -351,6 +419,673 @@ int64_t ggml_cycles_per_ms(void) {
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
+//
+// quantization
+//
+
+#define QK 32
+
+// AVX routines provided by GH user Const-me
+// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+#if __AVX2__ || __AVX512F__
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+{
+    // Load 16 bytes from memory
+    __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m256i bytes = _mm256_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    __m256i high = _mm256_andnot_si256( lowMask, bytes );
+    __m256i low = _mm256_and_si256( lowMask, bytes );
+    high = _mm256_slli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+    return bytes;
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+}
+#endif
+
+// method 5
+// blocks of QK elements
+// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+typedef struct {
+    float   d; // delta
+    uint8_t qs[QK / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+
+// method 4
+// blocks of QK elements
+// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+typedef struct {
+    float   d;
+    float   m;
+    uint8_t qs[QK / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+
+// reference implementation for deterministic creation of model files
+static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < QK; l += 2) {
+            const float v0 = x[i*QK + l + 0]*id;
+            const float v1 = x[i*QK + l + 1]*id;
+
+            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
+            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+
+            assert(vi0 >= 0 && vi0 < 16);
+            assert(vi1 >= 0 && vi1 < 16);
+
+            pp[l/2] = vi0 | (vi1 << 4);
+        }
+
+        memcpy(y[i].qs, pp, sizeof(pp));
+    }
+}
+
+static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    block_q4_0 * restrict y = vy;
+
+#if defined(__POWER9_VECTOR__)
+    const vector float v85 = vec_splats(8.5f);
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = *(vector float *)(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
+        amaxv[0] = vec_max(amaxv[0], amaxv[2]);
+        amaxv[4] = vec_max(amaxv[4], amaxv[6]);
+        //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
+        amaxv[0] = vec_max(amaxv[0], amaxv[4]);
+
+        amax = MAX(
+                MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
+                MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        y[i].d = d;
+
+        const vector float vid = vec_splats(id);
+        uint8_t * restrict pb = y[i].qs;
+        for (int l = 0; l < 8; l++) {
+            const vector float vf  = vec_madd(srcv[l], vid, v85);
+            const vector signed int vi = vec_signed(vf);
+
+            pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
+            pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
+        }
+    }
+#elif __ARM_NEON
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        // absolute max
+        const float amax = MAX(
+                MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
+                MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
+            const int32x4_t   vi = vcvtq_s32_f32(vf);
+
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+        }
+    }
+#elif defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 7.0f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
+        const __m256i off = _mm256_set1_epi8( 8 );
+        i0 = _mm256_add_epi8( i0, off );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = wasm_v128_load(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
+
+        amax = MAX(
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
+            const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+
+            y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+            y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+        }
+    }
+#else
+    // scalar
+    quantize_row_q4_0_reference(x, y, k);
+#endif
+}
+
+static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    block_q4_1 * restrict y = vy;
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+        y[i].m = min;
+
+        for (int l = 0; l < QK; l += 2) {
+            const float v0 = (x[i*QK + l + 0] - min)*id;
+            const float v1 = (x[i*QK + l + 1] - min)*id;
+
+            const uint8_t vi0 = roundf(v0);
+            const uint8_t vi1 = roundf(v1);
+
+            assert(vi0 >= 0 && vi0 < 16);
+            assert(vi1 >= 0 && vi1 < 16);
+
+            pp[l/2] = vi0 | (vi1 << 4);
+        }
+
+        memcpy(y[i].qs, pp, sizeof(pp));
+    }
+}
+
+static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+
+    block_q4_1 * restrict y = vy;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max for the block
+        __m256 vmax;
+        vmax = _mm256_max_ps( v0, v1 );
+        vmax = _mm256_max_ps( vmax, v2 );
+        vmax = _mm256_max_ps( vmax, v3 );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Compute min for the block
+        __m256 vmin;
+        vmin = _mm256_min_ps( v0, v1 );
+        vmin = _mm256_min_ps( vmin, v2 );
+        vmin = _mm256_min_ps( vmin, v3 );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
+        // Quantize these floats
+        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].m = minScalar;
+        y[i].d = d;
+
+        // x = (x-min)*id
+        const __m256 mul = _mm256_set1_ps( id );
+        const __m256 off = _mm256_set1_ps( minScalar );
+        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
+        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
+        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
+        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
+    }
+#elif __ARM_NEON
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv[8];
+        float32x4_t minv[8];
+        float32x4_t maxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+
+        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
+
+        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
+
+        const float min = vminvq_f32(minv[0]);
+        const float max = vmaxvq_f32(maxv[0]);
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+        y[i].m = min;
+
+        const float32x4_t minv0 = vdupq_n_f32(min);
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
+            const int32x4_t   vi = vcvtq_s32_f32(v);
+
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+        }
+    }
+#else
+    // scalar
+    quantize_row_q4_1_reference(x, vy, k);
+#endif
+}
+
+static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    const block_q4_0 * restrict x = vx;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // scale factor
+        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Subtract 8 from the integers
+            vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_mul_ps(vf[j], d_v);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float32x4_t vd = vdupq_n_f32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Convert to signed 8-bit integers
+            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
+            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+
+            // Subtract 8 from each byte
+            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
+            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+
+            // Interleave and combine
+            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
+            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+
+            const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+
+            // convert to 2x int16x8_t
+            const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
+            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+
+            // Multiply by d
+            const float32x4_t r0 = vmulq_f32(vf_0, vd);
+            const float32x4_t r1 = vmulq_f32(vf_1, vd);
+            const float32x4_t r2 = vmulq_f32(vf_2, vd);
+            const float32x4_t r3 = vmulq_f32(vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+    }
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
+
+            y[i*QK + l + 0] = v0;
+            y[i*QK + l + 1] = v1;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+        }
+    }
+#endif
+}
+
+static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    const block_q4_1 * restrict x = vx;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
+        const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float32x4_t vd = vdupq_n_f32(x[i].d);
+        const float32x4_t vm = vdupq_n_f32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Interleave and combine
+            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+
+            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+
+            // convert to 2x uint16x8_t
+            const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
+            const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+
+            // multiply by d and add m
+            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+    }
+#else
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;
+        const float m = x[i].m;
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK + l + 0] = v0;
+            y[i*QK + l + 1] = v1;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+        }
+    }
+#endif
+}
+
 //
 // simd mappings
 //
@@ -443,7 +1178,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
         }                                                         \
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
-        res = vaddvq_f32(vaddq_f32(t0, t1));                      \
+        res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
     }
 
     #define GGML_F16_VEC                GGML_F16x8
@@ -538,13 +1273,36 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 #define GGML_F16_EPR  8
 
 // F16 arithmetic is not supported by AVX, so we use F32 instead
-// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32
 
 #define GGML_F32Cx8             __m256
 #define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
 #define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+
+#if defined(__F16C__)
+// the  _mm256_cvt intrinsics require F16C
 #define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++)
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    _mm256_storeu_ps(arr, y);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP16_TO_FP32(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
+
 #define GGML_F32Cx8_FMA         GGML_F32x8_FMA
 #define GGML_F32Cx8_ADD         _mm256_add_ps
 #define GGML_F32Cx8_MUL         _mm256_mul_ps
@@ -856,9 +1614,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
-    ggml_float sumf = 0.0;
-
 #ifdef GGML_SIMD
+    float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
 
     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -884,14 +1641,46 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     }
 #else
     // scalar
+    ggml_float sumf = 0.0;
     for (int i = 0; i < n; ++i) {
-        sumf += x[i]*y[i];
+        sumf += (ggml_float)(x[i]*y[i]);
     }
 #endif
 
     *s = sumf;
 }
 
+#if __AVX512F__ && QK == 32
+static inline __m512 dot_q4_0_oneblock_avx512(
+    __m512 acc,
+    const block_q4_0 * restrict x,
+    const block_q4_0 * restrict y,
+    int i
+) {
+    // Compute combined scale for the block
+    __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
+
+    __m256i bx = bytesFromNibbles( x[i].qs );
+    __m256i by = bytesFromNibbles( y[i].qs );
+
+    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+    const __m256i off = _mm256_set1_epi8( 8 );
+    bx = _mm256_sub_epi8( bx, off );
+    by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 16 signed bytes into int16_t
+    __m512i x32 = _mm512_cvtepi8_epi16( bx );
+    __m512i y32 = _mm512_cvtepi8_epi16( by );
+    // Compute products of int16_t integers, add pairwise
+    __m512i i64 = _mm512_madd_epi16( x32, y32 );
+
+    // Convert int32_t to float
+    __m512 p = _mm512_cvtepi32_ps( i64 );
+    // Apply the scale, and accumulate
+    return _mm512_fmadd_ps( d, p, acc );
+}
+#endif
+
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -917,130 +1706,531 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     // leftovers
     for (int i = np; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #else
     for (int i = 0; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #endif
 
     *s = sumf;
 }
 
-// compute GGML_VEC_DOT_UNROLL dot products at once
-// xs - x row stride in bytes
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
-    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK;
 
-    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+    assert(n % QK == 0);
+    assert(nb % 2 == 0);
 
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
-    }
+    const block_q4_0 * restrict x = vx;
+    const block_q4_0 * restrict y = vy;
 
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
+    ggml_float sumf = 0.0;
 
-    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+#if defined(__ARM_NEON)
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
 
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict y0 = &y[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q4_0 * restrict y1 = &y[i + 1];
 
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
-            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
-                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
-            }
-        }
-    }
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
 
-    // reduce sum0..sum3 to sum0
-    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
-    }
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
-        }
-    }
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
+        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int16x8_t
+        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+
+        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+
+        // scalar
+#if defined(__ARM_FEATURE_QRDMX)
+        sum0 += x0->d * y0->d * vaddvq_s32(p_0);
+        sum1 += x1->d * y1->d * vaddvq_s32(p_1);
 #else
-    for (int i = 0; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
-        }
-    }
+        sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
+        sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
 #endif
+#else
+           const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
 
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        s[i] = sumf[i];
-    }
-}
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
-inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
 
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
 
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
 
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] += x[i]*v;
+        // scalar
+#if defined(__ARM_FEATURE_QRDMX)
+        sum0 += x0->d * y0->d * vaddvq_s16(p_0);
+        sum1 += x1->d * y1->d * vaddvq_s16(p_1);
+#else
+        sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
+        sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
+#endif
+#endif
+    }
+
+    sumf = (ggml_float)(sum0 + sum1);
+#elif defined(__AVX512F__)
+    // Initialize accumulator with zeros
+    __m512 acc0 = _mm512_setzero_ps();
+    __m512 acc1 = _mm512_setzero_ps();
+
+    const int superblock_size = 8;
+    const int superblock_count = nb / superblock_size;
+    const int remainder = nb % superblock_size;
+
+    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+        int i = superblock_ix * superblock_size;
+
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
+    }
+
+    // Remainders
+    for (int i = superblock_count * superblock_size; i < nb; ++i) {
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
+    }
+
+    // Horizontal sum of all lanes of the accumulator
+    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        __m256i bx = bytesFromNibbles( x[i].qs );
+        __m256i by = bytesFromNibbles( y[i].qs );
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+        by = _mm256_sub_epi8( by, off );
+
+        // Sign-extend first 16 signed bytes into int16_t
+        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
+        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        // Compute products of int16_t integers, add pairwise
+        __m256i i32 = _mm256_madd_epi16( x16, y16 );
+
+        // Sign-extend last 16 signed bytes into int16_t vectors
+        x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
+        y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        // Accumulate products of int16_t integers
+        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( i32 );
+        // Apply the scale, and accumulate
+        acc = _mm256_fmadd_ps( d, p, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__wasm_simd128__)
+    // wasm simd
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &px[i + 0];
+        const block_q4_0 * restrict y0 = &py[i + 0];
+        const block_q4_0 * restrict x1 = &px[i + 1];
+        const block_q4_0 * restrict y1 = &py[i + 1];
+
+        const v128_t m4b = wasm_u8x16_splat(0xf);
+        const v128_t s8b = wasm_i8x16_splat(0x8);
+
+        const v128_t v0_0 = wasm_v128_load(x0.qs);
+        const v128_t v0_1 = wasm_v128_load(y0.qs);
+        const v128_t v1_0 = wasm_v128_load(x1.qs);
+        const v128_t v1_1 = wasm_v128_load(y1.qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
+
+        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
+
+        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
+
+        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
+
+        // sub 8
+        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
+
+        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
+
+        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
+
+        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
+        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
+
+        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
+        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
+
+        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
+        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
+
+        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
+        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
+
+        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
+        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
+
+        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
+        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
+
+        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
+        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
+
+        sum0 += x0->d * y0->d * (
+                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
+                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
+                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
+                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
+        sum1 += x1->d * y1->d * (
+                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
+                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
+                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
+                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
     }
+
+    sumf = sum0 + sum1;
 #else
     // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] += x[i]*v;
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const uint8_t * restrict p1 = y[i].qs;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+            const float f1 = d0*((int8_t) (v0 >> 4)  - 8);
+
+            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+            const float f3 = d1*((int8_t) (v1 >> 4)  - 8);
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK;
+
+    const block_q4_1 * restrict x = vx;
+    const block_q4_1 * restrict y = vy;
+
+    float sumf = 0.0;
+
+#if defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    // Accumulator for constant offsets
+    float acc_offset = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
+
+        const float * m0 = &x[i].m;
+        const float * m1 = &y[i].m;
+
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        const __m256 m1v = _mm256_broadcast_ss( m1 );
+
+        // Compute combined scale for the block
+        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
+
+        // Compute cross scales for the block
+        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
+        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
+        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        __m256i bx = bytesFromNibbles( x[i].qs );
+        __m256i by = bytesFromNibbles( y[i].qs );
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
+
+        // Sign-extend first 16 signed bytes into int16_t
+        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
+        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        // Compute products of int16_t integers, add pairwise
+        __m256i i32 = _mm256_madd_epi16( x16, y16 );
+
+        // Sign-extend last 16 signed bytes into int16_t vectors
+        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
+        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        // Accumulate products of int16_t integers
+        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
+
+        // compute sums of unsigned bytes in bx, by in blocks of 8.
+        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
+        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
+        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
+        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
+        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
+        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
+        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( i32 );
+        // Apply the scale, and accumulate
+        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
+        acc = _mm256_fmadd_ps( scale_01, p, acc );
+        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
+        // acc_offset += m0*m1 (for each entry in the block)
+        acc_offset += (*m0)*(*m1);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
+#elif defined(__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict y0 = &y[i + 0];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+
+        // and with 0xf
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+
+        // dot product into uint16x8_t
+        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
+        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
+
+        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
+        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+
+        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+
+        sum00 += x0->m*y0->m;
+        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+    }
+
+    sumf = QK*sum00 + sum01 + sum10 + sum11;
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const float m0 = x[i].m;
+        const float m1 = y[i].m;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const uint8_t * restrict p1 = y[i].qs;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4)  + m0;
+
+            const float f2 = d1*(v1 & 0xf) + m1;
+            const float f3 = d1*(v1 >> 4)  + m1;
+
+            sumf += f0*f2 + f1*f3;
+        }
     }
 #endif
+
+    *s = sumf;
 }
 
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
 #if defined(GGML_SIMD)
     const int np = (n & ~(GGML_F16_STEP - 1));
 
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
 
     GGML_F16_VEC ax[GGML_F16_ARR];
     GGML_F16_VEC ay[GGML_F16_ARR];
 
     for (int i = 0; i < np; i += GGML_F16_STEP) {
         for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
 
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
         }
     }
 
     // leftovers
     for (int i = np; i < n; ++i) {
-        GGML_ASSERT(false);
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+        y[i] += x[i]*v;
     }
 #else
+    // scalar
     for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+        y[i] += x[i]*v;
     }
 #endif
 }
@@ -1075,19 +2265,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const ggml_float GELU_COEF_A    = 0.044715;
-static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
-    return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -1114,11 +2304,40 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+    return x/(1.0f + expf(-x));
+}
+
+inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_silu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_SILU_FP16
+inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+#endif
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
-        sum += x[i];
+        sum += (ggml_float)x[i];
     }
     *s = sum;
 #else
@@ -1128,7 +2347,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
-    ggml_float max = -INFINITY;
+    float max = -INFINITY;
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
@@ -1138,7 +2357,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #endif
 }
 
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
 
 //
 // logging
@@ -1168,7 +2390,21 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 // data types
 //
 
+static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
+    QK,
+    QK,
+    1,
+    1,
+    1,
+    1,
+    1,
+};
+
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+    sizeof(block_q4_0),
+    sizeof(block_q4_1),
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -1176,6 +2412,9 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(float  ),
 };
 
+// don't forget to update the array above when adding new types
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
 
@@ -1195,7 +2434,9 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "STEP",
     "RELU",
     "GELU",
+    "SILU",
     "NORM",
+    "RMS_NORM",
 
     "MUL_MAT",
 
@@ -1216,6 +2457,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
@@ -1235,7 +2478,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "silu(x)",
     "norm(x)",
+    "rms_norm(x)",
 
     "X*Y",
 
@@ -1256,6 +2501,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+
 //
 // ggml object
 //
@@ -1282,6 +2529,7 @@ struct ggml_context {
     size_t mem_size;
     void * mem_buffer;
     bool   mem_buffer_owned;
+    bool   mem_buffer_mlocked;
 
     int n_objects;
 
@@ -1383,13 +2631,21 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type];
+    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+}
+
+int ggml_blck_size(enum ggml_type type) {
+    return GGML_BLCK_SIZE[type];
 }
 
 size_t ggml_type_size(enum ggml_type type) {
     return GGML_TYPE_SIZE[type];
 }
 
+float ggml_type_sizef(enum ggml_type type) {
+    return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
@@ -1416,9 +2672,13 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        (t0->ne[0]  == t1->ne[0])  &&
-        (t0->ne[2]  == t1->ne[2])  &&
-        (t0->ne[3]  == t1->ne[3]);
+        (t0->ne[0] == t1->ne[0])  &&
+        (t0->ne[2] == t1->ne[2])  &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
+static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+    return tensor->nb[0] > tensor->nb[1];
 }
 
 static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
@@ -1426,7 +2686,7 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
 
     return
         tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
-        tensor->nb[1] == tensor->nb[0]*tensor->ne[0] &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
@@ -1477,7 +2737,7 @@ static inline int ggml_up(int n, int m) {
 
 // assert that pointer is aligned to GGML_MEM_ALIGN
 #define ggml_assert_aligned(ptr) \
-    assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+    GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -1491,7 +2751,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize time system (required on Windows)
         ggml_time_init();
 
-        // initialize GELU, EXP and F32 tables
+        // initialize GELU, SILU and EXP F32 tables
         {
             const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -1501,12 +2761,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 memcpy(&ii, &ui, sizeof(ii));
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
-                table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
+                table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
+                table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-            GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+            GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
         // initialize g_state
@@ -1551,16 +2812,19 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     }
 
     *ctx = (struct ggml_context) {
-        /*.mem_size         =*/ params.mem_size,
-        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
-        /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
-        /*.n_objects        =*/ 0,
-        /*.objects_begin    =*/ NULL,
-        /*.objects_end      =*/ NULL,
-        /*.scratch          =*/ { 0, 0, NULL, },
-        /*.scratch_save     =*/ { 0, 0, NULL, },
+        /*.mem_size           =*/ params.mem_size,
+        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
+        /*.mem_buffer_mlocked =*/ false,
+        /*.n_objects          =*/ 0,
+        /*.objects_begin      =*/ NULL,
+        /*.objects_end        =*/ NULL,
+        /*.scratch            =*/ { 0, 0, NULL, },
+        /*.scratch_save       =*/ { 0, 0, NULL, },
     };
 
+    GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
+
     ggml_assert_aligned(ctx->mem_buffer);
 
     GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
@@ -1583,6 +2847,14 @@ void ggml_free(struct ggml_context * ctx) {
             GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
                     __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
 
+#if GGML_MLOCK_SUPPORT
+            if (ctx->mem_buffer_mlocked) {
+                if (munlock(ctx->mem_buffer, ctx->mem_size)) {
+                    fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
+                }
+            }
+#endif
+
             if (ctx->mem_buffer_owned) {
                 free(ctx->mem_buffer);
             }
@@ -1611,6 +2883,37 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+bool ggml_mlock_supported(void) {
+    return GGML_MLOCK_SUPPORT;
+}
+
+#if GGML_MLOCK_SUPPORT
+#ifdef __APPLE__
+    #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
+                             "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MLOCK (ulimit -l)."
+#else
+    #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
+#endif
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
+    if (ctx->mem_buffer_mlocked) {
+        return true;
+    }
+    if (mlock(ctx->mem_buffer, ctx->mem_size)) {
+        int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
+                           ctx->mem_size, strerror(errno));
+        GGML_ASSERT(ret >= 0);
+        return false;
+    }
+    ctx->mem_buffer_mlocked = true;
+    return true;
+}
+#else // GGML_MLOCK_SUPPORT
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
+    *err_p = strdup("can't mlock because it's not supported on this system");
+    return false;
+}
+#endif // GGML_MLOCK_SUPPORT
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_tensor * ggml_new_tensor_impl(
@@ -1629,8 +2932,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
     size_t size_needed = 0;
 
     if (data == NULL) {
-        size_needed += GGML_TYPE_SIZE[type];
-        for (int i = 0; i < n_dims; i++) {
+        size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
+        for (int i = 1; i < n_dims; i++) {
             size_needed *= ne[i];
         }
         // align to GGML_MEM_ALIGN
@@ -1723,7 +3026,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
     }
 
     result->nb[0] = GGML_TYPE_SIZE[type];
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+    result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
     }
 
@@ -1820,6 +3124,14 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -1857,7 +3169,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
             } break;
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 
@@ -1872,6 +3184,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -1909,7 +3229,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
             } break;
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 
@@ -1918,6 +3238,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
 
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -1954,6 +3282,14 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -1988,6 +3324,14 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
 
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -2024,6 +3368,14 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -2114,7 +3466,7 @@ struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2153,7 +3505,7 @@ struct ggml_tensor * ggml_sub_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2192,7 +3544,7 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2201,7 +3553,7 @@ struct ggml_tensor * ggml_mul_impl(
     }
 
     if (inplace) {
-        assert(is_node == false);
+        GGML_ASSERT(is_node == false);
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -2235,7 +3587,7 @@ struct ggml_tensor * ggml_div_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2244,7 +3596,7 @@ struct ggml_tensor * ggml_div_impl(
     }
 
     if (inplace) {
-        assert(is_node == false);
+        GGML_ASSERT(is_node == false);
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -2368,7 +3720,7 @@ struct ggml_tensor * ggml_mean(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement
+        GGML_ASSERT(false); // TODO: implement
         is_node = true;
     }
 
@@ -2389,7 +3741,7 @@ struct ggml_tensor * ggml_repeat(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    assert(ggml_can_repeat(a, b));
+    GGML_ASSERT(ggml_can_repeat(a, b));
 
     bool is_node = false;
 
@@ -2616,80 +3968,148 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
-// ggml_norm
+// ggml_silu
 
-struct ggml_tensor * ggml_norm_impl(
+struct ggml_tensor * ggml_silu_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a,
+        struct ggml_tensor * a,
         bool inplace) {
     bool is_node = false;
 
     if (!inplace && (a->grad)) {
-        assert(false); // TODO: implement backward
         is_node = true;
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_NORM;
+    result->op   = GGML_OP_SILU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL; // TODO: maybe store epsilon here?
+    result->src1 = NULL;
 
     return result;
 }
 
-struct ggml_tensor * ggml_norm(
+struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, false);
+    return ggml_silu_impl(ctx, a, false);
 }
 
-struct ggml_tensor * ggml_norm_inplace(
+struct ggml_tensor * ggml_silu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, true);
+    return ggml_silu_impl(ctx, a, true);
 }
 
-// ggml_mul_mat
+// ggml_norm
 
-struct ggml_tensor * ggml_mul_mat(
+struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    assert(ggml_can_mul_mat(a, b));
-
+        bool inplace) {
     bool is_node = false;
 
-    if (a->grad || b->grad) {
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
-    const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_MUL_MAT;
+    result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = b;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
 
-// ggml_scale
-
-struct ggml_tensor * ggml_scale_impl(
+struct ggml_tensor * ggml_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_norm_impl(ctx, a, true);
+}
+
+struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, true);
+}
+
+// ggml_mul_mat
+
+struct ggml_tensor * ggml_mul_mat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_mul_mat(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_MUL_MAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_scale
+
+struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         bool inplace) {
-    assert(ggml_is_scalar(b));
-    assert(ggml_is_padded_1d(a));
+    GGML_ASSERT(ggml_is_scalar(b));
+    GGML_ASSERT(ggml_is_padded_1d(a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2726,12 +4146,12 @@ struct ggml_tensor * ggml_cpy_impl(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         bool inplace) {
-    assert(ggml_nelements(a) == ggml_nelements(b));
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2766,14 +4186,14 @@ struct ggml_tensor * ggml_reshape(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    assert(ggml_is_contiguous(a));
-    assert(ggml_is_contiguous(b));
-    assert(ggml_nelements(a) == ggml_nelements(b));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2792,13 +4212,13 @@ struct ggml_tensor * ggml_reshape_2d(
         struct ggml_tensor  * a,
         int                   ne0,
         int                   ne1) {
-    assert(ggml_is_contiguous(a));
-    assert(ggml_nelements(a) == ne0*ne1);
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
 
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2819,13 +4239,13 @@ struct ggml_tensor * ggml_reshape_3d(
         int                   ne0,
         int                   ne1,
         int                   ne2) {
-    assert(ggml_is_contiguous(a));
-    assert(ggml_nelements(a) == ne0*ne1*ne2);
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
 
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2848,7 +4268,7 @@ struct ggml_tensor * ggml_view_1d(
         int                   ne0,
         size_t                offset) {
     if (a->grad) {
-        assert(false); // gradient propagation is not supported
+        GGML_ASSERT(false); // gradient propagation is not supported
     }
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
@@ -2871,7 +4291,7 @@ struct ggml_tensor * ggml_view_2d(
         size_t                nb1,
         size_t                offset) {
     if (a->grad) {
-        assert(false); // gradient propagation is not supported
+        GGML_ASSERT(false); // gradient propagation is not supported
     }
 
     const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
@@ -2899,22 +4319,22 @@ struct ggml_tensor * ggml_permute(
         int                   axis1,
         int                   axis2,
         int                   axis3) {
-    assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
-    assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
-    assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
-    assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
-
-    assert(axis0 != axis1);
-    assert(axis0 != axis2);
-    assert(axis0 != axis3);
-    assert(axis1 != axis2);
-    assert(axis1 != axis3);
-    assert(axis2 != axis3);
+    GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+    GGML_ASSERT(axis0 != axis1);
+    GGML_ASSERT(axis0 != axis2);
+    GGML_ASSERT(axis0 != axis3);
+    GGML_ASSERT(axis1 != axis2);
+    GGML_ASSERT(axis1 != axis3);
+    GGML_ASSERT(axis2 != axis3);
 
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2959,7 +4379,7 @@ struct ggml_tensor * ggml_transpose(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2985,12 +4405,12 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
 
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3015,7 +4435,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3040,7 +4460,7 @@ struct ggml_tensor * ggml_soft_max(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3064,11 +4484,11 @@ struct ggml_tensor * ggml_rope(
         int                   n_past,
         int                   n_dims,
         int                   mode) {
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3095,13 +4515,13 @@ struct ggml_tensor * ggml_conv_1d_1s(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_is_matrix(b));
-    assert(a->ne[1] == b->ne[1]);
-    assert(a->ne[3] == 1);
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3122,13 +4542,13 @@ struct ggml_tensor * ggml_conv_1d_2s(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_is_matrix(b));
-    assert(a->ne[1] == b->ne[1]);
-    assert(a->ne[3] == 1);
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3151,7 +4571,7 @@ struct ggml_tensor * ggml_flash_attn(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         bool                  masked) {
-    assert(ggml_can_mul_mat(k, q));
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
     bool is_node = false;
@@ -3183,7 +4603,7 @@ struct ggml_tensor * ggml_flash_ff(
         struct ggml_tensor  * b1,
         struct ggml_tensor  * c0,
         struct ggml_tensor  * c1) {
-    assert(ggml_can_mul_mat(b0, a));
+    GGML_ASSERT(ggml_can_mul_mat(b0, a));
     // TODO: more checks
 
     bool is_node = false;
@@ -3214,7 +4634,7 @@ void ggml_set_param(
         struct ggml_tensor * tensor) {
     tensor->is_param = true;
 
-    assert(tensor->grad == NULL);
+    GGML_ASSERT(tensor->grad == NULL);
     tensor->grad = ggml_dup_tensor(ctx, tensor);
 }
 
@@ -3224,9 +4644,9 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -3249,7 +4669,7 @@ static void ggml_compute_forward_dup_f16(
 
     if (src0->nb[0] == sizeof(ggml_fp16_t)) {
         if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             const size_t rs = ne00*nb00;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3265,7 +4685,7 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         } else if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3287,7 +4707,7 @@ static void ggml_compute_forward_dup_f16(
         //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3303,7 +4723,7 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3353,7 +4773,7 @@ static void ggml_compute_forward_dup_f32(
 
     if (src0->nb[0] == sizeof(float)) {
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             const size_t rs = ne00*nb00;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3369,7 +4789,7 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3391,7 +4811,7 @@ static void ggml_compute_forward_dup_f32(
         //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3407,7 +4827,7 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3441,6 +4861,8 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -3516,13 +4938,15 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3566,13 +4990,15 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3616,13 +5042,15 @@ static void ggml_compute_forward_mul(
             {
                 ggml_compute_forward_mul_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3666,13 +5094,15 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3712,13 +5142,15 @@ static void ggml_compute_forward_sqr(
             {
                 ggml_compute_forward_sqr_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3758,13 +5190,15 @@ static void ggml_compute_forward_sqrt(
             {
                 ggml_compute_forward_sqrt_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3814,13 +5248,15 @@ static void ggml_compute_forward_sum(
             {
                 ggml_compute_forward_sum_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3889,13 +5325,15 @@ static void ggml_compute_forward_mean(
             {
                 ggml_compute_forward_mean_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3951,13 +5389,15 @@ static void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3997,13 +5437,15 @@ static void ggml_compute_forward_abs(
             {
                 ggml_compute_forward_abs_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4043,13 +5485,15 @@ static void ggml_compute_forward_sgn(
             {
                 ggml_compute_forward_sgn_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4089,13 +5533,15 @@ static void ggml_compute_forward_neg(
             {
                 ggml_compute_forward_neg_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4135,13 +5581,15 @@ static void ggml_compute_forward_step(
             {
                 ggml_compute_forward_step_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4181,13 +5629,15 @@ static void ggml_compute_forward_relu(
             {
                 ggml_compute_forward_relu_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4244,17 +5694,87 @@ static void ggml_compute_forward_gelu(
             {
                 ggml_compute_forward_gelu_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //printf("XXXXXXXX gelu\n");
+}
+
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
 
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -4285,7 +5805,7 @@ static void ggml_compute_forward_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+    const float eps = 1e-5f; // TODO: make this a parameter
 
     // TODO: optimize
     for (int i03 = 0; i03 < ne03; i03++) {
@@ -4293,23 +5813,24 @@ static void ggml_compute_forward_norm_f32(
             for (int i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
+                ggml_float sum = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00];
+                    sum += (ggml_float)x[i00];
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
                 ggml_float sum2 = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    ggml_float v = x[i00] - mean;
+                    float v = x[i00] - mean;
                     y[i00] = v;
-                    sum2 += v*v;
+                    sum2 += (ggml_float)(v*v);
                 }
 
-                const float scale = 1.0/sqrt(sum2/ne00 + eps);
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -4326,17 +5847,100 @@ static void ggml_compute_forward_norm(
             {
                 ggml_compute_forward_norm_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const float eps = 1e-6f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0f/sqrtf(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
 
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -4346,7 +5950,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    UNUSED(src0);
+    //const int ne00 = src0->ne[0];
+    //const int ne01 = src0->ne[1];
 
     const int ne10 = src1->ne[0];
 
@@ -4354,10 +5959,10 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
-             (ne0 >= 32 && ne1  >= 32   && ne10 >= 32)
-            )) {
-        //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
     }
 
@@ -4378,8 +5983,11 @@ static void ggml_compute_forward_mul_mat_f32(
     const int ne02 = src0->ne[2];
     const int ne03 = src0->ne[3];
 
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     const int ne10 = src1->ne[0];
+#endif
     const int ne11 = src1->ne[1];
+#ifndef NDEBUG
     const int ne12 = src1->ne[2];
     const int ne13 = src1->ne[3];
 
@@ -4387,14 +5995,16 @@ static void ggml_compute_forward_mul_mat_f32(
     const int ne1  = dst->ne[1];
     const int ne2  = dst->ne[2];
     const int ne3  = dst->ne[3];
-    const int ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
+#endif
     const int nb01 = src0->nb[1];
     const int nb02 = src0->nb[2];
     const int nb03 = src0->nb[3];
 
+#ifndef NDEBUG
     const int nb10 = src1->nb[0];
+#endif
     const int nb11 = src1->nb[1];
     const int nb12 = src1->nb[2];
     const int nb13 = src1->nb[3];
@@ -4412,8 +6022,9 @@ static void ggml_compute_forward_mul_mat_f32(
     assert(ne2  == ne12);
     assert(ne3  == ne13);
 
-    // TODO: we don't support permuted src0
-    assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+    // we don't support permuted src0 or src1
+    assert(nb00 == sizeof(float));
+    assert(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
     assert(nb0 == sizeof(float));
@@ -4428,14 +6039,9 @@ static void ggml_compute_forward_mul_mat_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
         if (params->ith != 0) {
             return;
         }
@@ -4450,19 +6056,17 @@ static void ggml_compute_forward_mul_mat_f32(
 
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
-                const float * x = (float *) (src0->data);
+                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // zT = y * xT
-                {
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne10,
-                                     x, ne10,
-                            0.0f,    d, ne01);
-                }
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -4473,130 +6077,242 @@ static void ggml_compute_forward_mul_mat_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
+        return;
+    }
+
+    // parallelize by src0 rows using ggml_vec_dot_f32
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
+
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            ggml_vec_dot_f32(ne00,
+                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_mul_mat_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
             return;
         }
 
-        // TODO: fix this memset (wsize is overestimated)
-        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
 
         float * const wdata = params->wdata;
 
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    size_t id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        for (int i00 = 0; i00 < ne00; ++i00) {
+                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        }
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
+            }
+        }
+
+        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
+        return;
+    }
+#endif
 
-        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+    if (params->type == GGML_TASK_INIT) {
+        ggml_fp16_t * const wdata = params->wdata;
 
-        for (int k = 1; k < nth; k++) {
-            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+        size_t id = 0;
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    for (int i10 = 0; i10 < ne10; ++i10) {
+                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                    }
+                }
+            }
         }
 
+        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
         return;
     }
 
-    if (nb01 >= nb00) {
-        // TODO: do not support transposed src1
-        assert(nb10 == sizeof(float));
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
 
-        // parallelize by src0 rows using ggml_vec_dot_f32
+    // fp16 -> half the size, so divide by 2
+    // TODO: do not support transposed src1
+    assert(nb10/2 == sizeof(ggml_fp16_t));
 
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
+    // parallelize by src0 rows using ggml_vec_dot_f16
 
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                // src1 indices
-                const int i13 = i03;
-                const int i12 = i02;
-                const int i11 = ic;
+    ggml_fp16_t * wdata = params->wdata;
 
-                // dst indices
-                const int i0 = i01;
-                const int i1 = i11;
-                const int i2 = i02;
-                const int i3 = i03;
-
-                ggml_vec_dot_f32(ne00,
-                        (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                        (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
-            }
-        }
-    } else {
-        // parallelize by src1 columns using ggml_vec_mad_f32
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-        // total columns in src1
-        const int nc = ne10;
+        const int i13 = i03;
+        const int i12 = i02;
 
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
 
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
+        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
 
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        float * const wdata = params->wdata;
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
-
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
-
-                        // dst indices
-                        const int i1 = i11;
-                        const int i2 = i12;
-                        const int i3 = i13;
-
-                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
-                        ggml_vec_mad_f32(ne01,
-                                (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
-                                (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
-                               *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
-                    }
-                }
-            }
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
         }
     }
 
-    //int64_t t1 = ggml_perf_time_us();
+    //int64_t t1 = ggml_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
     //if (t1 - t0 > 10) {
@@ -4604,13 +6320,35 @@ static void ggml_compute_forward_mul_mat_f32(
     //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
     //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
     //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
 
     //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
     //}
 }
 
-static void ggml_compute_forward_mul_mat_f16_f32(
+typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
+typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
+typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
+
+typedef struct {
+    dequantize_row_q_t dequantize_row_q;
+    quantize_row_q_t   quantize_row_q;
+    vec_dot_q_t        vec_dot_q;
+} quantize_fns_t;
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q = dequantize_row_q4_0,
+        .quantize_row_q   = quantize_row_q4_0,
+        .vec_dot_q        = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q = dequantize_row_q4_1,
+        .quantize_row_q   = quantize_row_q4_1,
+        .vec_dot_q        = ggml_vec_dot_q4_1,
+    },
+};
+
+static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4632,7 +6370,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int ne1  = dst->ne[1];
     const int ne2  = dst->ne[2];
     const int ne3  = dst->ne[3];
-    const int ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -4657,8 +6394,13 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     GGML_ASSERT(ne2  == ne12);
     GGML_ASSERT(ne3  == ne13);
 
-    // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+    const enum ggml_type type = src0->type;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    vec_dot_q_t      const vec_dot_q      = quantize_fns[type].vec_dot_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -4673,14 +6415,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
         if (params->ith != 0) {
             return;
         }
@@ -4694,58 +6431,29 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
         float * const wdata = params->wdata;
+        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 {
-                    int id = 0;
+                    size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
-                        for (int i00 = 0; i00 < ne00; ++i00) {
-                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        }
+                        dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
                     }
                 }
 
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
-                //      float * z =                          wdata + ne00*ne01;
-
-                // z = x * yT
-                //{
-                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                //            ne01, ne11, ne00,
-                //            1.0f, x, ne00,
-                //                  y, ne00,
-                //            0.0f, z, ne11);
-                //}
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                // transpose z
-                //for (int j = 0; j < ne11; ++j) {
-                //    for (int i = 0; i < ne01; ++i) {
-                //        d[j*ne01 + i] = z[i*ne11 + j];
-                //    }
-                //}
-
-                {
-#if 1
-                    // zT = y * xT
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne00,
-                                     x, ne00,
-                            0.0f,    d, ne01);
-#else
-                    // zT = (xT * y)T
-                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
-                            ne01, ne11, ne10,
-                            1.0f,    x, ne00,
-                                     y, ne00,
-                            0.0f,    d, ne01);
-#endif
-                }
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -4756,150 +6464,62 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        if (nb01 >= nb00) {
-            ggml_fp16_t * const wdata = params->wdata;
-
-            int id = 0;
-            for (int i13 = 0; i13 < ne13; ++i13) {
-                for (int i12 = 0; i12 < ne12; ++i12) {
-                    for (int i11 = 0; i11 < ne11; ++i11) {
-                        for (int i10 = 0; i10 < ne10; ++i10) {
-                            wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                        }
-                    }
+        char * wdata = params->wdata;
+        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += row_size;
                 }
             }
-
-            GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
-            return;
         }
 
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
-
-        ggml_fp16_t * const wdata = params->wdata;
-
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
-
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
-
-        for (int i = ic0; i < ic1; ++i) {
-            ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
-        }
-
-        for (int k = 1; k < nth; k++) {
-            for (int i = ic0; i < ic1; ++i) {
-                ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
-            }
-        }
-
         return;
     }
 
-    if (nb01 >= nb00) {
-        // fp16 -> half the size, so divide by 2
-        // TODO: do not support transposed src1
-        assert(nb10/2 == sizeof(ggml_fp16_t));
-
-        // parallelize by src0 rows using ggml_vec_dot_f16
-
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
-
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
-
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
-
-        ggml_fp16_t * wdata = params->wdata;
-
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int i13 = i03;
-            const int i12 = i02;
-
-            const int i0 = i01;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
-
-            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+    // parallelize by src0 rows using ggml_vec_dot_q
 
-            assert(ne00 % 32 == 0);
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
-            }
-        }
-    } else {
-        // parallelize by src1 columns using ggml_vec_mad_f16
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
-
-        // total columns in src1
-        const int nc = ne10;
-
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        ggml_fp16_t * const wdata = params->wdata;
+    void * wdata = params->wdata;
+    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
 
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    // dst indices
-                    const int i1 = i11;
-                    const int i2 = i12;
-                    const int i3 = i13;
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-                    ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+        const int i13 = i03;
+        const int i12 = i02;
 
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
 
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*row_size));
 
-                        assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-                        ggml_fp16_t * src0_col =  (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
-                        float         src1_val = *      (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+        assert(ne00 % 32 == 0);
 
-                        ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
-                    }
-                }
-            }
+        for (int ic = 0; ic < ne11; ++ic) {
+            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
         }
     }
 
@@ -4922,6 +6542,11 @@ static void ggml_compute_forward_mul_mat(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@@ -4935,9 +6560,37 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
+
+#if 0
+    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
+        static int first = 8;
+        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
+        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
+        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+        if (first) {
+            --first;
+        } else {
+            for (int k = 0; k < dst->ne[1]; ++k) {
+                for (int j = 0; j < dst->ne[0]/16; ++j) {
+                    for (int i = 0; i < 16; ++i) {
+                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+                    }
+                    printf("\n");
+                }
+                printf("\n");
+            }
+            printf("\n");
+            exit(0);
+        }
+    } else {
+        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
+        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
+        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    }
+#endif
 }
 
 // ggml_compute_forward_scale
@@ -4987,13 +6640,15 @@ static void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5051,6 +6706,35 @@ static void ggml_compute_forward_transpose(
 
 // ggml_compute_forward_get_rows
 
+static void ggml_compute_forward_get_rows_q(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == GGML_TYPE_SIZE[type]);
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+    }
+}
+
 static void ggml_compute_forward_get_rows_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5112,6 +6796,11 @@ static void ggml_compute_forward_get_rows(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
@@ -5125,9 +6814,27 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
 }
 
 // ggml_compute_forward_diag_mask_inf
@@ -5178,13 +6885,15 @@ static void ggml_compute_forward_diag_mask_inf(
             {
                 ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5223,6 +6932,7 @@ static void ggml_compute_forward_soft_max_f32(
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(p[i]));
         }
 #endif
@@ -5241,12 +6951,12 @@ static void ggml_compute_forward_soft_max_f32(
                 ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
-                sum += val;
+                sum += (ggml_float)val;
                 p[i] = val;
             }
         }
 
-        assert(sum > 0.0f);
+        assert(sum > 0.0);
 
         sum = 1.0/sum;
         ggml_vec_scale_f32(nc, p, sum);
@@ -5269,13 +6979,15 @@ static void ggml_compute_forward_soft_max(
             {
                 ggml_compute_forward_soft_max_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5320,16 +7032,16 @@ static void ggml_compute_forward_rope_f32(
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int i1 = 0; i1 < ne1; i1++) {
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = src[0];
-                    double x1 = src[1];
+                    const float x0 = src[0];
+                    const float x1 = src[1];
 
                     dst_data[0] = x0*cos_theta - x1*sin_theta;
                     dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -5339,23 +7051,84 @@ static void ggml_compute_forward_rope_f32(
     }
 }
 
+static void ggml_compute_forward_rope_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    //const int ne0 = src0->ne[0];
+    const int ne1 = src0->ne[1];
+    const int ne2 = src0->ne[2];
+    const int ne3 = src0->ne[3];
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    const int nb3 = src0->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = (mode == 0 ? n_past + i2 : i2);
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < n_dims; i0 += 2) {
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
+
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
+
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                          ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    const float x0 = ggml_fp16_to_fp32(src[0]);
+                    const float x1 = ggml_fp16_to_fp32(src[1]);
+
+                    dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
+                    dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_rope_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5616,6 +7389,8 @@ static void ggml_compute_forward_conv_1d_1s(
             {
                 ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5882,6 +7657,8 @@ static void ggml_compute_forward_conv_1d_2s(
             {
                 ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5996,7 +7773,7 @@ static void ggml_compute_forward_flash_attn_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -6043,7 +7820,7 @@ static void ggml_compute_forward_flash_attn_f32(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -6064,7 +7841,7 @@ static void ggml_compute_forward_flash_attn_f32(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -6076,7 +7853,7 @@ static void ggml_compute_forward_flash_attn_f32(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -6205,7 +7982,7 @@ static void ggml_compute_forward_flash_attn_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -6269,7 +8046,7 @@ static void ggml_compute_forward_flash_attn_f16(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -6290,7 +8067,7 @@ static void ggml_compute_forward_flash_attn_f16(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -6302,7 +8079,7 @@ static void ggml_compute_forward_flash_attn_f16(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -6365,12 +8142,14 @@ static void ggml_compute_forward_flash_attn(
             {
                 ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -6574,12 +8353,14 @@ static void ggml_compute_forward_flash_ff(
             {
                 GGML_ASSERT(false); // TODO
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -6587,7 +8368,7 @@ static void ggml_compute_forward_flash_ff(
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
-    assert(params);
+    GGML_ASSERT(params);
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -6654,10 +8435,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_NORM:
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -6835,7 +8624,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_MEAN:
             {
-                assert(false); // TODO: implement
+                GGML_ASSERT(false); // TODO: implement
             } break;
         case GGML_OP_REPEAT:
             {
@@ -6890,17 +8679,25 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_GELU:
             {
-                assert(false); // TODO: not implemented
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_SILU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_NORM:
             {
-                assert(false); // TODO: not implemented
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_MUL_MAT:
             {
                 if (src0->grad) {
                     // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
-                    assert(false);
+                    GGML_ASSERT(false);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -7016,12 +8813,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
 
     if (node->op == GGML_OP_NONE && node->grad == NULL) {
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
-        assert(cgraph->n_leafs < GGML_MAX_NODES);
+        GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);
 
         cgraph->leafs[cgraph->n_leafs] = node;
         cgraph->n_leafs++;
     } else {
-        assert(cgraph->n_nodes < GGML_MAX_NODES);
+        GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);
 
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->grads[cgraph->n_nodes] = node->grad;
@@ -7045,7 +8842,7 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
 
     if (n_new > 0) {
         // the last added node should always be starting point
-        assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
     }
 }
 
@@ -7076,7 +8873,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
 struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
     struct ggml_cgraph result = *gf;
 
-    assert(gf->n_nodes > 0);
+    GGML_ASSERT(gf->n_nodes > 0);
 
     // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
     if (keep) {
@@ -7239,10 +9036,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 }
 
 void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    if (cgraph->n_threads <= 0) {
-        cgraph->n_threads = 8;
-    }
-
     const int n_threads = cgraph->n_threads;
 
     struct ggml_compute_state_shared state_shared = {
@@ -7275,7 +9068,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             };
 
             int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            assert(rc == 0);
+            GGML_ASSERT(rc == 0);
             UNUSED(rc);
         }
     }
@@ -7317,7 +9110,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
+                case GGML_OP_SILU:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
                 case GGML_OP_NORM:
+                case GGML_OP_RMS_NORM:
                     {
                         node->n_tasks = n_threads;
                     } break;
@@ -7334,32 +9132,35 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
-                        // TODO: better way to determine if the matrix is transposed
-                        if (node->src0->nb[1] < node->src0->nb[0]) {
-                            cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
-                        } else {
-                            if (node->src0->type == GGML_TYPE_F16 &&
-                                node->src1->type == GGML_TYPE_F32) {
+                        if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                    node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                       //       the threads are still spinning
-                                    cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
-                                    //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
-                                    //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
-                                    //printf("cur = %zu\n", cur);
-                                } else {
-                                    cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
-                                }
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                   //       the threads are still spinning
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+                                //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+                                //printf("cur = %zu\n", cur);
+                            } else {
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+                            }
 #else
-                                cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
 #endif
-                            } else if (node->src0->type == GGML_TYPE_F32 &&
-                                       node->src1->type == GGML_TYPE_F32) {
-                                cur = 0;
-                            } else {
-                                GGML_ASSERT(false);
+                        } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
+                            cur = 0;
+                        } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                            } else
+#endif
+                            {
+                                cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
                             }
+                        } else {
+                            GGML_ASSERT(false);
                         }
 
                         work_size = MAX(work_size, cur);
@@ -7460,13 +9261,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_COUNT:
                     {
-                        assert(false);
+                        GGML_ASSERT(false);
                     } break;
             }
         }
 
         if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            assert(false); // TODO: better handling
+            GGML_ASSERT(false); // TODO: better handling
         }
 
         if (work_size > 0 && cgraph->work == NULL) {
@@ -7632,7 +9433,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         for (int j = 0; j < n_threads - 1; j++) {
             int rc = ggml_thread_join(workers[j].thrd, NULL);
-            assert(rc == 0);
+            GGML_ASSERT(rc == 0);
             UNUSED(rc);
         }
 
@@ -7739,7 +9540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
     char color[16];
 
     FILE * fp = fopen(filename, "w");
-    assert(fp);
+    GGML_ASSERT(fp);
 
     fprintf(fp, "digraph G {\n");
     fprintf(fp, "  newrank = true;\n");
@@ -7787,7 +9588,7 @@ label=\"%d [%d, %d] | <x>%s",
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
 label=\"<x>%.1e\"; ]\n",
-                    (void *) node, color, ggml_get_f32_1d(node, 0));
+                    (void *) node, color, (double)ggml_get_f32_1d(node, 0));
         } else {
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
@@ -7897,7 +9698,7 @@ static enum ggml_opt_result ggml_opt_adam(
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
         struct ggml_cgraph * gb) {
-    assert(ggml_is_scalar(f));
+    GGML_ASSERT(ggml_is_scalar(f));
 
     gf->n_threads = params.n_threads;
     gb->n_threads = params.n_threads;
@@ -7911,7 +9712,7 @@ static enum ggml_opt_result ggml_opt_adam(
         if (gf->nodes[i]->is_param) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
 
-            assert(np < GGML_MAX_PARAMS);
+            GGML_ASSERT(np < GGML_MAX_PARAMS);
 
             ps[np++] = gf->nodes[i];
             nx += ggml_nelements(gf->nodes[i]);
@@ -8025,7 +9826,7 @@ static enum ggml_opt_result ggml_opt_adam(
             if (params.past <= t) {
                 const float rate = (pf[t%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
@@ -8104,7 +9905,7 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;
 
-    if (*step <= 0.) {
+    if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
 
@@ -8192,7 +9993,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_cgraph * gb) {
     if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
         params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
-        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
             return GGML_OPT_INVALID_WOLFE;
         }
     }
@@ -8211,7 +10012,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         if (gf->nodes[i]->is_param) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
 
-            assert(np < GGML_MAX_PARAMS);
+            GGML_ASSERT(np < GGML_MAX_PARAMS);
 
             ps[np++] = gf->nodes[i];
             nx += ggml_nelements(gf->nodes[i]);
@@ -8313,8 +10114,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
 
-        if (xnorm < 1.0) {
-            xnorm = 1.0;
+        if (xnorm < 1.0f) {
+            xnorm = 1.0f;
         }
         if (gnorm/xnorm <= params.lbfgs.eps) {
             // converged
@@ -8327,7 +10128,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             if (params.past <= k) {
                 const float rate = (pf[k%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
@@ -8523,6 +10324,54 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+
+        quantize_row_q4_0_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK*sizeof(block_q4_0));
+}
+
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+
+        quantize_row_q4_1_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK*sizeof(block_q4_1));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
 int ggml_cpu_has_avx(void) {
 #if defined(__AVX__)
     return 1;
index bb7dd8d8f77390d48b41cb0c4cf079626308e2ef..be7b038dfb1629b41401b854ee97bc31409cdc4d 100644 (file)
 
 #include <sys/time.h>
 
-#ifdef __ARM_NEON
+#if defined(__ARM_NEON)
 #include "arm_neon.h"
+#elif defined(__AVX__) || defined(__AVX2__)
+#include "immintrin.h"
 #endif
 
 #ifndef MIN
@@ -26,8 +28,12 @@ const int M = 1280;
 const int N = 1536;
 const int K = 1280;
 
-const int QK = 64;
-#define QB 7
+//const int M = 64;
+//const int N = 64;
+//const int K = 64;
+
+#define QK 64
+#define QB 4
 
 //#define GGML_GQ_USE_FP16_SCALE
 
@@ -41,8 +47,12 @@ const int QK = 64;
 #define GGML_GQ_TO_FP32(x) (x)
 #endif
 
-#define gq_quant_t uint64_t
 #define gq_t_bits 64
+#define gq_quant_t uint64_t
+
+float frand() {
+    return (float) rand() / (float) RAND_MAX;
+}
 
 uint64_t get_time_us() {
     struct timeval tv;
@@ -50,6 +60,47 @@ uint64_t get_time_us() {
     return tv.tv_sec * 1000000 + tv.tv_usec;
 }
 
+#if defined(__AVX2__)
+// horizontally reduce 8 32-bit integers
+static inline uint32_t _mm256_hadd_epi32_gg(__m256i v) {
+    __m128i v0 = _mm256_extractf128_si256(v, 0);
+    __m128i v1 = _mm256_extractf128_si256(v, 1);
+
+    v0 = _mm_add_epi32(v0, v1);
+
+    v1 = _mm_shuffle_epi32(v0, 0x0e);
+    v0 = _mm_add_epi32(v0, v1);
+
+    v1 = _mm_shuffle_epi32(v0, 0x01);
+    v0 = _mm_add_epi32(v0, v1);
+
+    return _mm_cvtsi128_si32(v0);
+}
+
+//static inline float _mm256_hadd_epi32_gg(__m256i v) {
+//    const __m256 v0 = _mm256_cvtepi32_ps(v);
+//    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v0), _mm256_extractf128_ps(v0, 1));
+//    const __m128 t1 = _mm_hadd_ps(t0, t0);
+//
+//    return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+//}
+
+// horizontally reduce 32 8-bit integers
+static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
+    __m256i v1 = _mm256_maddubs_epi16(v0, _mm256_set1_epi8(1));
+    __m256i v2 = _mm256_madd_epi16   (v1, _mm256_set1_epi16(1));
+
+    return _mm256_hadd_epi32_gg(v2);
+}
+
+static inline float _mm256_hadd_ps_gg(__m256 v) {
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
+    const __m128 t1 = _mm_hadd_ps(t0, t0);
+
+    return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+}
+#endif
+
 //
 // naive implementation
 //
@@ -74,6 +125,21 @@ void mul_mat_f32_naive(
 // method 1
 //
 
+static inline int quantize_1_blocks_per_row(int k) {
+    return k/QK;
+}
+
+static inline int quantize_1_quants_per_block() {
+    return QK/gq_t_bits;
+}
+
+static inline int quantize_1_row_size(int k) {
+    const int nb = quantize_1_blocks_per_row(k);
+    const int nq = quantize_1_quants_per_block();
+
+    return nb*(2*sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t));
+}
+
 void quantize_1(const float * src, void * dst, int n, int k) {
     char * p0 = dst;
 
@@ -215,6 +281,7 @@ void mul_mat_gq_1(
 
 //
 // method 2
+// n-bit quantization (2nd attempt)
 //
 
 static inline int quantize_2_blocks_per_row(int k) {
@@ -244,15 +311,41 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
 
     gq_quant_t pp[QB];
 
+    static const int32_t sh[32] = {
+        0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+    };
+
     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = src[i*QK + l];
-            if (v < min) min = v;
-            if (v > max) max = v;
+#ifdef __ARM_NEON
+        {
+            float32x4_t minv = vdupq_n_f32(FLT_MAX);
+            float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
+
+            for (int l = 0; l < QK; l += 4) {
+                float32x4_t v = vld1q_f32(src + i*QK + l);
+                minv = vminq_f32(minv, v);
+                maxv = vmaxq_f32(maxv, v);
+            }
+
+            float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
+            float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
+
+            min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
+            max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
         }
+#else
+        {
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l];
+                if (v < min) min = v;
+                if (v > max) max = v;
+            }
+        }
+#endif
 
         const float d = (max - min) / ((1 << QB) - 1);
         const float id = d ? 1.0/d : 0.0;
@@ -263,18 +356,150 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
         for (int s = 0; s < nq; ++s) {
             memset(pp, 0, sizeof(pp));
 
+#if 1
             for (int l = 0; l < gq_t_bits; l++) {
                 const   float v = src[i*QK + s*gq_t_bits + l];
-                const uint8_t q = (v - min)*id;
+                const uint8_t q = (v - min)*id + frand();
 
                 for (int b = 0; b < QB; b++) {
                     pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
                 }
             }
+#elif defined(__ARM_NEON)
+#if 1
+            {
+                uint32_t ppt[2*4*QB];
+
+                float32x4_t minv = vdupq_n_f32(min);
+                float32x4_t idv  = vdupq_n_f32(id);
+
+                assert(gq_t_bits % 16 == 0);
+
+                uint32x4_t p0[QB] = { vdupq_n_u32(0) };
+                uint32x4_t p1[QB] = { vdupq_n_u32(0) };
+
+                for (int l = 0; l < gq_t_bits; l += 16) {
+                    float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
+                    float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
+                    float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
+                    float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
+
+                    v0 = vsubq_f32(v0, minv);
+                    v1 = vsubq_f32(v1, minv);
+                    v2 = vsubq_f32(v2, minv);
+                    v3 = vsubq_f32(v3, minv);
+
+                    v0 = vmulq_f32(v0, idv);
+                    v1 = vmulq_f32(v1, idv);
+                    v2 = vmulq_f32(v2, idv);
+                    v3 = vmulq_f32(v3, idv);
+
+#if 1
+                    v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
+                    v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
+                    v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
+                    v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
+#endif
+
+                    uint32x4_t q0 = vcvtq_u32_f32(v0);
+                    uint32x4_t q1 = vcvtq_u32_f32(v1);
+                    uint32x4_t q2 = vcvtq_u32_f32(v2);
+                    uint32x4_t q3 = vcvtq_u32_f32(v3);
+
+                    for (int b = 0; b < QB; ++b) {
+                        uint32x4_t m = vdupq_n_u32(1 << b);
+                        uint32x4_t r = vdupq_n_u32(-b);
+
+                        if (l < 32) {
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4)));
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8)));
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12)));
+                        } else {
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32)));
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28)));
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24)));
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20)));
+                        }
+                    }
+                }
+
+#if QB == 4
+                vst1q_u32((uint32_t *) ppt + 0,  p0[0]);
+                vst1q_u32((uint32_t *) ppt + 4,  p1[0]);
+                vst1q_u32((uint32_t *) ppt + 8,  p0[1]);
+                vst1q_u32((uint32_t *) ppt + 12, p1[1]);
+                vst1q_u32((uint32_t *) ppt + 16, p0[2]);
+                vst1q_u32((uint32_t *) ppt + 20, p1[2]);
+                vst1q_u32((uint32_t *) ppt + 24, p0[3]);
+                vst1q_u32((uint32_t *) ppt + 28, p1[3]);
+
+                pp[0] = (ppt[0]  | ppt[1]  | ppt[2]  | ppt[3] ) | ((uint64_t) (ppt[4]  | ppt[5]  | ppt[6]  | ppt[7]) ) << 32;
+                pp[1] = (ppt[8]  | ppt[9]  | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
+                pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
+                pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
+#else
+                for (int b = 0; b < QB; ++b) {
+                    vst1q_u32((uint32_t *) ppt + 0,  p0[b]);
+                    vst1q_u32((uint32_t *) ppt + 4,  p1[b]);
+
+                    pp[b] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32;
+                }
+#endif
+            }
+#else
+            // less optimal SIMD
+            {
+                float32x4_t minv = vdupq_n_f32(min);
+                float32x4_t idv  = vdupq_n_f32(id);
+
+                assert(gq_t_bits == 64);
+                uint8_t qq[gq_t_bits];
+
+                for (int l = 0; l < gq_t_bits; l += 16) {
+                    float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
+                    float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
+                    float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
+                    float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
+
+                    v0 = vsubq_f32(v0, minv);
+                    v1 = vsubq_f32(v1, minv);
+                    v2 = vsubq_f32(v2, minv);
+                    v3 = vsubq_f32(v3, minv);
+
+                    v0 = vmulq_f32(v0, idv);
+                    v1 = vmulq_f32(v1, idv);
+                    v2 = vmulq_f32(v2, idv);
+                    v3 = vmulq_f32(v3, idv);
+
+#if 0
+                    v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
+                    v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
+                    v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
+                    v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
+#endif
+
+                    uint32x4_t q0 = vcvtq_u32_f32(v0);
+                    uint32x4_t q1 = vcvtq_u32_f32(v1);
+                    uint32x4_t q2 = vcvtq_u32_f32(v2);
+                    uint32x4_t q3 = vcvtq_u32_f32(v3);
+
+                    // store in qq as uint8_t
+                    vst1_u8(qq + l + 0, vmovn_u16(vcombine_u16(vmovn_u32(q0), vmovn_u32(q1))));
+                    vst1_u8(qq + l + 8, vmovn_u16(vcombine_u16(vmovn_u32(q2), vmovn_u32(q3))));
+                }
 
-            for (int b = 0; b < QB; b++) {
-                pb[i*nq*QB + s*QB + b] = pp[b];
+                for (int l = 0; l < gq_t_bits; l++) {
+                    for (int b = 0; b < QB; b++) {
+                        const uint64_t ql = qq[l];
+                        /*pp[b] |= qq[l] & (1 << b) ? (1ULL << l) : 0;*/
+                        pp[b] |= ((ql & (1 << b)) >> b) << l;
+                    }
+                }
             }
+#endif
+#endif
+            memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
         }
     }
 }
@@ -290,9 +515,6 @@ void quantize_2(const float * restrict src, char * restrict dst, int n, int k) {
 }
 
 void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
-    float sumf[(QB + 1)*(QB + 1)];
-    memset(sumf, 0, sizeof(sumf));
-
     const int nb = quantize_2_blocks_per_row(n);
     const int nq = quantize_2_quants_per_block();
 
@@ -305,10 +527,9 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons
     const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
     const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
 
-#if 1
-    float s0[QB + 1];
-    float s1[QB + 1];
+    float sumf = 0.0;
 
+#if 1
     for (int i = 0; i < nb; i++) {
         const float m0 = GGML_GQ_TO_FP32(pm0[i]);
         const float d0 = GGML_GQ_TO_FP32(pd0[i]);
@@ -316,6 +537,99 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons
         const float m1 = GGML_GQ_TO_FP32(pm1[i]);
         const float d1 = GGML_GQ_TO_FP32(pd1[i]);
 
+#if QB == 4
+        int isum01 = 0;
+        int isum10 = 0;
+        int isum11 = 0;
+
+        for (int s = 0; s < nq; ++s) {
+            const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
+            const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
+
+#define bpcnt(x) __builtin_popcountll(x)
+            isum01 += (1 << 0)*(bpcnt(mm1[0]));
+            isum01 += (1 << 1)*(bpcnt(mm1[1]));
+            isum01 += (1 << 2)*(bpcnt(mm1[2]));
+            isum01 += (1 << 3)*(bpcnt(mm1[3]));
+
+            isum10 += (1 << 0)*(bpcnt(mm0[0]));
+            isum10 += (1 << 1)*(bpcnt(mm0[1]));
+            isum10 += (1 << 2)*(bpcnt(mm0[2]));
+            isum10 += (1 << 3)*(bpcnt(mm0[3]));
+
+            isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
+            isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
+            isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0]));
+            isum11 += (1 << 3)*(bpcnt(mm0[0] & mm1[3]) + bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]) + bpcnt(mm0[3] & mm1[0]));
+            isum11 += (1 << 4)*(bpcnt(mm0[1] & mm1[3]) + bpcnt(mm0[2] & mm1[2]) + bpcnt(mm0[3] & mm1[1]));
+            isum11 += (1 << 5)*(bpcnt(mm0[2] & mm1[3]) + bpcnt(mm0[3] & mm1[2]));
+            isum11 += (1 << 6)*(bpcnt(mm0[3] & mm1[3]));
+#undef bpcnt
+        }
+
+        sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
+#elif QB == 3
+        int isum01 = 0;
+        int isum10 = 0;
+        int isum11 = 0;
+
+        for (int s = 0; s < nq; ++s) {
+            const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
+            const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
+
+#if gq_t_bits == 32
+#define bpcnt(x) __builtin_popcount(x)
+#else
+#define bpcnt(x) __builtin_popcountll(x)
+#endif
+            isum01 += (1 << 0)*(bpcnt(mm1[0]));
+            isum01 += (1 << 1)*(bpcnt(mm1[1]));
+            isum01 += (1 << 2)*(bpcnt(mm1[2]));
+
+            isum10 += (1 << 0)*(bpcnt(mm0[0]));
+            isum10 += (1 << 1)*(bpcnt(mm0[1]));
+            isum10 += (1 << 2)*(bpcnt(mm0[2]));
+
+            isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
+            isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
+            isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0]));
+            isum11 += (1 << 3)*(bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]));
+            isum11 += (1 << 4)*(bpcnt(mm0[2] & mm1[2]));
+#undef bpcnt
+        }
+
+        sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
+#elif QB == 2
+        int isum01 = 0;
+        int isum10 = 0;
+        int isum11 = 0;
+
+        for (int s = 0; s < nq; ++s) {
+            const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
+            const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
+
+#if gq_t_bits == 32
+#define bpcnt(x) __builtin_popcount(x)
+#else
+#define bpcnt(x) __builtin_popcountll(x)
+#endif
+            isum01 += (1 << 0)*(bpcnt(mm1[0]));
+            isum01 += (1 << 1)*(bpcnt(mm1[1]));
+
+            isum10 += (1 << 0)*(bpcnt(mm0[0]));
+            isum10 += (1 << 1)*(bpcnt(mm0[1]));
+
+            isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
+            isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
+            isum11 += (1 << 2)*(bpcnt(mm0[1] & mm1[1]));
+#undef bpcnt
+        }
+
+        sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
+#else
+        float s0[QB + 1];
+        float s1[QB + 1];
+
         s0[0] = m0;
         s1[0] = m1;
 
@@ -329,36 +643,17 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons
                 const gq_quant_t mm0 = q0 ? pb0[i*nq*QB + s*QB + q0 - 1] : -1ULL;
                 for (int q1 = 0; q1 < QB + 1; q1++) {
                     const gq_quant_t mm1 = q1 ? pb1[i*nq*QB + s*QB + q1 - 1] : -1ULL;
-                    sumf[q0*(QB + 1) + q1] += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
+                    sumf += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
                 }
             }
         }
+#endif
     }
 #else
-    // SIMD-ify with the assumptions:
-    // - nb is a multiple of 4
-    // - gq_scale_t is float
-    // - gq_quant_t is uint64_t
-    // - QB == 7
-    assert(nb % 4 == 0);
-
-#ifdef __ARM_NEON
-#else
-    // TODO
-#endif
-
+#error "not implemented"
 #endif
 
-    for (int q0 = 0; q0 < QB + 1; q0++) {
-        for (int q1 = 1; q1 < QB + 1; q1++) {
-            sumf[q0*(QB + 1)] += sumf[q0*(QB + 1) + q1];
-        }
-    }
-
-    *s = sumf[0];
-    for (int q0 = 1; q0 < QB + 1; q0++) {
-        *s += sumf[q0*(QB + 1)];
-    }
+    *s = sumf;
 }
 
 // use vec_dot_gq_2 to compute the dot product of two rows
@@ -384,83 +679,1904 @@ void mul_mat_gq_2(
     }
 }
 
-int main(int argc, const char ** argv) {
-    assert(sizeof(gq_quant_t)*8 == gq_t_bits);
+//
+// method 3
+// (does not work)
+//
 
-    float * src0 = (float *)malloc(sizeof(float)*M*K);
-    float * src1 = (float *)malloc(sizeof(float)*N*K);
-    float * dst  = (float *)malloc(sizeof(float)*M*N);
+static inline int quantize_3_blocks_per_row(int k) {
+    return k/QK;
+}
 
-    for (int i = 0; i < M*K; i++) {
-        src0[i] = rand() / (float)RAND_MAX;
-    }
+static inline int quantize_3_quants_per_block() {
+    return QK/gq_t_bits;
+}
 
-    for (int i = 0; i < N*K; i++) {
-        src1[i] = rand() / (float)RAND_MAX;
-    }
+static inline int quantize_3_row_size(int k) {
+    const int nb = quantize_3_blocks_per_row(k);
+    const int nq = quantize_3_quants_per_block();
+
+    return nb*(sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t));
+}
 
-    void * src0_gq = calloc(1, quantize_2_row_size(K)*M);
-    void * src1_gq = calloc(1, quantize_2_row_size(K)*N);
+void quantize_3_row(const float * restrict src, void * restrict dst, int k) {
+    assert(k % QK == 0);
 
-    const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K;
-    const size_t sizegq  = quantize_2_row_size(K)*M + quantize_2_row_size(K)*N;
+    const int nb = quantize_3_blocks_per_row(k);
+    const int nq = quantize_3_quants_per_block();
 
-    printf("compression: %f\n", (float)sizegq/sizef16);
+    gq_scale_t * restrict pd = (gq_scale_t *) (dst);
+    gq_quant_t * restrict pb = (gq_quant_t *) (pd + nb);
 
-    int method = 0;
-    if (argc > 1) {
-        method = atoi(argv[1]);
-    }
+    gq_quant_t pp[QB];
 
-    // convert fp32 -> gq
-    {
-        const uint64_t t_start = get_time_us();
+    static const int32_t sh[32] = {
+        0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+    };
 
-        if (method == 1) {
-            quantize_1(src0, src0_gq, M, K);
-            quantize_1(src1, src1_gq, N, K);
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // abs max
+
+#ifdef __ARM_NEON
+        {
+            // min / max
+            //float32x4_t minv = vdupq_n_f32(FLT_MAX);
+            //float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
+
+            //for (int l = 0; l < QK; l += 4) {
+            //    float32x4_t v = vld1q_f32(src + i*QK + l);
+            //    minv = vminq_f32(minv, v);
+            //    maxv = vmaxq_f32(maxv, v);
+            //}
+
+            //float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
+            //float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
+
+            //min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
+            //max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
+
+            // abs max
+            float32x4_t amaxv = vdupq_n_f32(0.0f);
+
+            for (int l = 0; l < QK; l += 4) {
+                float32x4_t v = vld1q_f32(src + i*QK + l);
+                amaxv = vmaxq_f32(amaxv, vabsq_f32(v));
+            }
+
+            float32x2_t amaxv32 = vpmax_f32(vget_low_f32(amaxv), vget_high_f32(amaxv));
+
+            amax = MAX(vget_lane_f32(amaxv32, 0), vget_lane_f32(amaxv32, 1));
+        }
+#else
+        {
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l];
+                amax = MAX(amax, fabsf(v));
+            }
         }
+#endif
 
-        if (method == 2) {
-            quantize_2(src0, src0_gq, M, K);
-            quantize_2(src1, src1_gq, N, K);
+        const float d = amax / ((1 << (QB - 1)) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        pd[i] = GGML_FP32_TO_GQ(d);
+
+        for (int s = 0; s < nq; ++s) {
+            memset(pp, 0, sizeof(pp));
+
+#if 0
+            for (int l = 0; l < gq_t_bits; l++) {
+                const   float v = src[i*QK + s*gq_t_bits + l];
+                const uint8_t q = v*id + frand();
+
+                for (int b = 0; b < QB; b++) {
+                    pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
+                }
+            }
+#elif defined(__ARM_NEON)
+            {
+                uint32_t ppt[2*4*QB];
+
+                float32x4_t idv  = vdupq_n_f32(id);
+
+                assert(gq_t_bits == 64);
+
+                uint32x4_t p0[QB] = { vdupq_n_u32(0) };
+                uint32x4_t p1[QB] = { vdupq_n_u32(0) };
+
+                for (int l = 0; l < gq_t_bits; l += 16) {
+                    float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
+                    float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
+                    float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
+                    float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
+
+                    v0 = vmulq_f32(v0, idv);
+                    v1 = vmulq_f32(v1, idv);
+                    v2 = vmulq_f32(v2, idv);
+                    v3 = vmulq_f32(v3, idv);
+
+#if 1
+                    v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
+                    v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
+                    v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
+                    v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
+#endif
+
+                    uint32x4_t q0 = vcvtq_u32_f32(v0);
+                    uint32x4_t q1 = vcvtq_u32_f32(v1);
+                    uint32x4_t q2 = vcvtq_u32_f32(v2);
+                    uint32x4_t q3 = vcvtq_u32_f32(v3);
+
+                    for (int b = 0; b < QB; ++b) {
+                        uint32x4_t m = vdupq_n_u32(1 << b);
+                        uint32x4_t r = vdupq_n_u32(-b);
+
+                        if (l < 32) {
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4)));
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8)));
+                            p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12)));
+                        } else {
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32)));
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28)));
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24)));
+                            p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20)));
+                        }
+                    }
+                }
+
+#if QB == 4
+                vst1q_u32((uint32_t *) ppt + 0,  p0[0]);
+                vst1q_u32((uint32_t *) ppt + 4,  p1[0]);
+                vst1q_u32((uint32_t *) ppt + 8,  p0[1]);
+                vst1q_u32((uint32_t *) ppt + 12, p1[1]);
+                vst1q_u32((uint32_t *) ppt + 16, p0[2]);
+                vst1q_u32((uint32_t *) ppt + 20, p1[2]);
+                vst1q_u32((uint32_t *) ppt + 24, p0[3]);
+                vst1q_u32((uint32_t *) ppt + 28, p1[3]);
+
+                pp[0] = (ppt[0]  | ppt[1]  | ppt[2]  | ppt[3] ) | ((uint64_t) (ppt[4]  | ppt[5]  | ppt[6]  | ppt[7]) ) << 32;
+                pp[1] = (ppt[8]  | ppt[9]  | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
+                pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
+                pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
+#else
+                for (int q = 0; q < QB; ++q) {
+                    vst1q_u32((uint32_t *) ppt + 0,  p0[q]);
+                    vst1q_u32((uint32_t *) ppt + 4,  p1[q]);
+
+                    pp[q] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32;
+                }
+#endif
+            }
+#endif
+            memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
         }
+    }
+}
 
-        const uint64_t t_end = get_time_us();
-        printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
+// reimplementation of quantize_3 using quantize_3_row
+void quantize_3(const float * restrict src, char * restrict dst, int n, int k) {
+    assert(k % QK == 0);
+
+    for (int j = 0; j < n; j++) {
+        quantize_3_row(src + j*k, dst, k);
+        dst = (char *) dst + quantize_3_row_size(k);
     }
+}
 
-    const int nIter = 1;
+void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    float sumf = 0.0f;
 
-    const clock_t start = clock();
-    const uint64_t start_us = get_time_us();
+    const int nb = quantize_3_blocks_per_row(n);
+    const int nq = quantize_3_quants_per_block();
 
-    double iM = 1.0/M;
-    double sum = 0.0f;
-    for (int i = 0; i < nIter; i++) {
-        if (method == 0) {
-            mul_mat_f32_naive(src0, src1, dst, M, N, K);
-        }
+    const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
+    const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
 
-        if (method == 1) {
-            mul_mat_gq_1(src0_gq, src1_gq, dst, M, N, K);
-        }
+    const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
+    const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
 
-        if (method == 2) {
-            mul_mat_gq_2(src0_gq, src1_gq, dst, M, N, K);
+#if 1
+    for (int i = 0; i < nb; i++) {
+        int isum = 0;
+
+#if QB == 4
+        for (int s = 0; s < nq; ++s) {
+            const gq_quant_t * restrict m0 = pb0 + i*nq*QB + s*QB;
+            const gq_quant_t * restrict m1 = pb1 + i*nq*QB + s*QB;
+
+            isum += (1 << 0)*(__builtin_popcountll(m0[0] & m1[0]));
+            isum += (1 << 1)*(__builtin_popcountll(m0[0] & m1[1]) + __builtin_popcountll(m0[1] & m1[0]));
+            isum += (1 << 2)*(__builtin_popcountll(m0[0] & m1[2]) + __builtin_popcountll(m0[1] & m1[1]) + __builtin_popcountll(m0[2] & m1[0]));
+            isum += (1 << 3)*(__builtin_popcountll(m0[0] & m1[3]) + __builtin_popcountll(m0[1] & m1[2]) + __builtin_popcountll(m0[2] & m1[1]) + __builtin_popcountll(m0[3] & m1[0]));
+            isum += (1 << 4)*(__builtin_popcountll(m0[1] & m1[3]) + __builtin_popcountll(m0[2] & m1[2]) + __builtin_popcountll(m0[3] & m1[1]));
+            isum += (1 << 5)*(__builtin_popcountll(m0[2] & m1[3]) + __builtin_popcountll(m0[3] & m1[2]));
+            isum += (1 << 6)*(__builtin_popcountll(m0[3] & m1[3]));
         }
-    }
+#else
+        for (int s = 0; s < nq; ++s) {
+            for (int q0 = 0; q0 < QB; q0++) {
+                const gq_quant_t mm0 = pb0[i*nq*QB + s*QB + q0];
+                for (int q1 = 0; q1 < QB; q1++) {
+                    const gq_quant_t mm1 = pb1[i*nq*QB + s*QB + q1];
+                    isum += (1 << (q0 + q1))*(__builtin_popcountll(mm0 & mm1));
+                }
+            }
+        }
+#endif
 
-    for (int i = 0; i < N; i++) {
-        sum += dst[i]*iM;
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        sumf += d0*d1*isum;
     }
+#else
+#ifdef __ARM_NEON
+    // gq_quant_t == uint64_t
+    for (int i = 0; i < nb; i += 4) {
+        int isum[4] = {0, 0, 0, 0};
+
+        for (int k = 0; k < 4; ++k) {
+            for (int s = 0; s < nq; ++s) {
+                const gq_quant_t * restrict m0 = pb0 + (i+k)*nq*QB + s*QB;
+                const gq_quant_t * restrict m1 = pb1 + (i+k)*nq*QB + s*QB;
+
+#if QB == 4
+#define bpcnt(x) __builtin_popcountll(x)
+                //isum[k] += (1ULL << 0)*(bpcnt(m0[0] & m1[0])) +
+                //           (1ULL << 1)*(bpcnt(m0[0] & m1[1]) + bpcnt(m0[1] & m1[0])) +
+                //           (1ULL << 2)*(bpcnt(m0[0] & m1[2]) + bpcnt(m0[1] & m1[1]) + bpcnt(m0[2] & m1[0])) +
+                //           (1ULL << 3)*(bpcnt(m0[0] & m1[3]) + bpcnt(m0[1] & m1[2]) + bpcnt(m0[2] & m1[1]) + bpcnt(m0[3] & m1[0])) +
+                //           (1ULL << 4)*(bpcnt(m0[1] & m1[3]) + bpcnt(m0[2] & m1[2]) + bpcnt(m0[3] & m1[1])) +
+                //           (1ULL << 5)*(bpcnt(m0[2] & m1[3]) + bpcnt(m0[3] & m1[2])) +
+                //           (1ULL << 6)*(bpcnt(m0[3] & m1[3]));
+#undef bpcnt
+
+                const uint8x8_t m00 = vld1_u8((const uint8_t *) (m0 + 0));
+                const uint8x8_t m01 = vld1_u8((const uint8_t *) (m0 + 1));
+                const uint8x8_t m02 = vld1_u8((const uint8_t *) (m0 + 2));
+                const uint8x8_t m03 = vld1_u8((const uint8_t *) (m0 + 3));
+
+                const uint8x8_t m10 = vld1_u8((const uint8_t *) (m1 + 0));
+                const uint8x8_t m11 = vld1_u8((const uint8_t *) (m1 + 1));
+                const uint8x8_t m12 = vld1_u8((const uint8_t *) (m1 + 2));
+                const uint8x8_t m13 = vld1_u8((const uint8_t *) (m1 + 3));
+
+                const uint8x8_t m00m10 = vand_u8(m00, m10);
+
+                const uint8x8_t m00m11 = vand_u8(m00, m11);
+                const uint8x8_t m01m10 = vand_u8(m01, m10);
+
+                const uint8x8_t m00m12 = vand_u8(m00, m12);
+                const uint8x8_t m01m11 = vand_u8(m01, m11);
+                const uint8x8_t m02m10 = vand_u8(m02, m10);
+
+                const uint8x8_t m00m13 = vand_u8(m00, m13);
+                const uint8x8_t m01m12 = vand_u8(m01, m12);
+                const uint8x8_t m02m11 = vand_u8(m02, m11);
+                const uint8x8_t m03m10 = vand_u8(m03, m10);
+
+                const uint8x8_t m01m13 = vand_u8(m01, m13);
+                const uint8x8_t m02m12 = vand_u8(m02, m12);
+                const uint8x8_t m03m11 = vand_u8(m03, m11);
+
+                const uint8x8_t m02m13 = vand_u8(m02, m13);
+                const uint8x8_t m03m12 = vand_u8(m03, m12);
+
+                const uint8x8_t m03m13 = vand_u8(m03, m13);
+
+#define bpcnt(x) vaddv_u8(vcnt_u8(x))
+                isum[k] += (1ULL << 0)*(bpcnt(m00m10)) +
+                           (1ULL << 1)*(bpcnt(m00m11) + bpcnt(m01m10)) +
+                           (1ULL << 2)*(bpcnt(m00m12) + bpcnt(m01m11) + bpcnt(m02m10)) +
+                           (1ULL << 3)*(bpcnt(m00m13) + bpcnt(m01m12) + bpcnt(m02m11) + bpcnt(m03m10)) +
+                           (1ULL << 4)*(bpcnt(m01m13) + bpcnt(m02m12) + bpcnt(m03m11)) +
+                           (1ULL << 5)*(bpcnt(m02m13) + bpcnt(m03m12)) +
+                           (1ULL << 6)*(bpcnt(m03m13));
+#undef bpcnt
+#else
+                for (int q0 = 0; q0 < QB; q0++) {
+                    const gq_quant_t mm0 = m0[q0];
+                    for (int q1 = 0; q1 < QB; q1++) {
+                        const gq_quant_t mm1 = m1[q1];
+                        isum[k] += (1ULL << (q0 + q1))*(__builtin_popcountll(mm0 & mm1));
+                    }
+                }
+#endif
+            }
+        }
 
-    {
-        const clock_t end = clock();
-        const uint64_t end_us = get_time_us();
-        printf("%s: elapsed ticks: %ld\n",  __func__, end - start);
-        printf("%s: elapsed us:    %d / %f ms\n",  __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter);
+        int32x4_t isumv = vld1q_s32(isum);
+
+        float32x4_t d0v = vld1q_f32(pd0 + i);
+        float32x4_t d1v = vld1q_f32(pd1 + i);
+
+        float32x4_t sumfv = vmulq_f32(d0v, d1v);
+
+        sumfv = vmulq_f32(sumfv, vcvtq_f32_s32(isumv));
+        sumf += vaddvq_f32(sumfv);
     }
+#else
+#error "not implemented"
+#endif
+
+#endif
+    *s = sumf;
+}
+
+// use vec_dot_gq_3 to compute the dot product of two rows
+void mul_mat_gq_3(
+    const void * src0,
+    const void * src1, // transposed
+         float * dst,
+    int m, int n, int k) {
+    assert(k % QK == 0);
+
+    const int nb = quantize_3_blocks_per_row(k);
+    const int nq = quantize_3_quants_per_block();
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            vec_dot_gq_3(k, dst + ir1, src0, src1);
+            src1 = (const char *) src1 + quantize_3_row_size(k);
+        }
+        src0 = (const char *) src0 +   quantize_3_row_size(k);
+        src1 = (const char *) src1 - n*quantize_3_row_size(k);
+
+        dst = (float *) dst + n;
+    }
+}
+
+//
+// method 4
+// 4-bit quantization
+//
+
+static inline int quantize_4_blocks_per_row(int k) {
+    return k/QK;
+}
+
+static inline int quantize_4_row_size(int k) {
+    const int nb = quantize_4_blocks_per_row(k);
+
+    return nb*(2*sizeof(gq_scale_t) + QK/2);
+}
+
+void quantize_4_row(const float * restrict src, void * restrict dst, int k) {
+    assert(k % QK == 0);
+    assert(QB == 4);
+
+    const int nb = quantize_4_blocks_per_row(k);
+
+    gq_scale_t * restrict pm = (gq_scale_t *) (dst);
+    gq_scale_t * restrict pd = (gq_scale_t *) (pm + nb);
+    uint8_t    * restrict pb = (uint8_t *)    (pd + nb);
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        memset(pp, 0, sizeof(pp));
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+#if defined(__AVX2__)
+        {
+            assert(QK == 64);
+            const int QK8 = QK/8;
+
+            __m256 srcv[QK8];
+            __m256 minv[QK8];
+            __m256 maxv[QK8];
+
+            for (int l = 0; l < QK8; l++) {
+                srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l);
+            }
+
+            for (int l = 0; l < QK8/2; l++) {
+                minv[2*l] = _mm256_min_ps(srcv[2*l], srcv[2*l+1]);
+                maxv[2*l] = _mm256_max_ps(srcv[2*l], srcv[2*l+1]);
+            }
+
+            for (int l = 0; l < QK8/4; l++) {
+                minv[4*l] = _mm256_min_ps(minv[4*l], minv[4*l+2]);
+                maxv[4*l] = _mm256_max_ps(maxv[4*l], maxv[4*l+2]);
+            }
+
+            for (int l = 0; l < QK8/8; l++) {
+                minv[8*l] = _mm256_min_ps(minv[8*l], minv[8*l+4]);
+                maxv[8*l] = _mm256_max_ps(maxv[8*l], maxv[8*l+4]);
+            }
+
+            //min = MIN(minv[0][0], MIN(minv[0][1], MIN(minv[0][2], MIN(minv[0][3], MIN(minv[0][4], MIN(minv[0][5], MIN(minv[0][6], minv[0][7])))))));
+            //max = MAX(maxv[0][0], MAX(maxv[0][1], MAX(maxv[0][2], MAX(maxv[0][3], MAX(maxv[0][4], MAX(maxv[0][5], MAX(maxv[0][6], maxv[0][7])))))));
+
+            const __m256 minv0_0 = _mm256_permute2f128_ps(minv[0], minv[0], 3);
+            const __m256 minv0_1 = _mm256_min_ps(minv[0], minv0_0);
+            const __m256 minv0_2 = _mm256_permute_ps(minv0_1, 0x4e);
+            const __m256 minv0_3 = _mm256_min_ps(minv0_1, minv0_2);
+            const __m256 minv0_4 = _mm256_permute_ps(minv0_3, 0xb1);
+            const __m256 minv0_5 = _mm256_min_ps(minv0_3, minv0_4);
+
+            const __m256 maxv0_0 = _mm256_permute2f128_ps(maxv[0], maxv[0], 3);
+            const __m256 maxv0_1 = _mm256_max_ps(maxv[0], maxv0_0);
+            const __m256 maxv0_2 = _mm256_permute_ps(maxv0_1, 0x4e);
+            const __m256 maxv0_3 = _mm256_max_ps(maxv0_1, maxv0_2);
+            const __m256 maxv0_4 = _mm256_permute_ps(maxv0_3, 0xb1);
+            const __m256 maxv0_5 = _mm256_max_ps(maxv0_3, maxv0_4);
+
+            min = _mm256_cvtss_f32(minv0_5);
+            max = _mm256_cvtss_f32(maxv0_5);
+
+            const float d = (max - min) / ((1 << QB) - 2);
+            const float id = d ? 1.0/d : 0.0;
+
+            pm[i] = GGML_FP32_TO_GQ(min);
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            const __m256 idv = _mm256_set1_ps(id);
+
+            for (int l = 0; l < QK/8; l++) {
+                __m256 v = _mm256_mul_ps(_mm256_sub_ps(srcv[l], _mm256_set1_ps(min)), idv);
+#if 0
+                v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
+                v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
+#endif
+
+                // convert to uint8
+                __m256i vi = _mm256_cvtps_epi32(v);
+
+                uint32_t vi_0 = _mm256_extract_epi32(vi, 0);
+                uint32_t vi_1 = _mm256_extract_epi32(vi, 1);
+                uint32_t vi_2 = _mm256_extract_epi32(vi, 2);
+                uint32_t vi_3 = _mm256_extract_epi32(vi, 3);
+
+                uint32_t vi_4 = _mm256_extract_epi32(vi, 4);
+                uint32_t vi_5 = _mm256_extract_epi32(vi, 5);
+                uint32_t vi_6 = _mm256_extract_epi32(vi, 6);
+                uint32_t vi_7 = _mm256_extract_epi32(vi, 7);
+
+                // convert to 4-bit, 2 consecutive packed into 1 byte
+                pp[4*l + 0] = vi_0 | (vi_1 << 4);
+                pp[4*l + 1] = vi_2 | (vi_3 << 4);
+                pp[4*l + 2] = vi_4 | (vi_5 << 4);
+                pp[4*l + 3] = vi_6 | (vi_7 << 4);
+
+                //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
+                //printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
+            }
+
+            memcpy(pb + i*QK/2, pp, sizeof(pp));
+        }
+#elif defined(__ARM_NEON) && 0
+        {
+            // TODO
+        }
+#else
+        {
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l];
+                if (v < min) min = v;
+                if (v > max) max = v;
+            }
+
+            const float d = (max - min) / ((1 << QB) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pm[i] = GGML_FP32_TO_GQ(min);
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            for (int l = 0; l < QK; l++) {
+                const float v = (src[i*QK + l] - min) * id;
+                const uint8_t vi = (uint8_t) (v + frand());
+                pp[l/2] |= (vi & 0xf) << (4*(l & 1));
+            }
+
+            memcpy(pb + i*QK/2, pp, sizeof(pp));
+        }
+#endif
+        //printf("min %f max %f\n", min, max);
+    }
+}
+
+// reimplementation of quantize_4 using quantize_4_row
+void quantize_4(const float * restrict src, char * restrict dst, int n, int k) {
+    assert(k % QK == 0);
+
+    for (int j = 0; j < n; j++) {
+        quantize_4_row(src + j*k, dst, k);
+        dst = (char *) dst + quantize_4_row_size(k);
+    }
+}
+
+void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = quantize_4_blocks_per_row(n);
+
+    const gq_scale_t * restrict pm0 = (const gq_scale_t *) x;
+    const gq_scale_t * restrict pm1 = (const gq_scale_t *) y;
+
+    const gq_scale_t * restrict pd0 = pm0 + nb;
+    const gq_scale_t * restrict pd1 = pm1 + nb;
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
+    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+
+    float sumf = 0.0;
+
+#if 0
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4)  + m0;
+
+            const float f2 = d1*(v1 & 0xf) + m1;
+            const float f3 = d1*(v1 >> 4)  + m1;
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#else
+#if defined(__AVX2__)
+#if QK == 64 && 0
+    __m256 sumv0 = _mm256_setzero_ps();
+    __m256 sumv1 = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const __m256 m0v = _mm256_set1_ps(m0);
+        const __m256 d0v = _mm256_set1_ps(d0);
+
+        const __m256 m1v = _mm256_set1_ps(m1);
+        const __m256 d1v = _mm256_set1_ps(d1);
+
+        const __m256i m4b = _mm256_set1_epi8(0xf);
+
+        __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
+
+        //_mm_prefetch((const char *) (p0 + 32), _MM_HINT_T0);
+        //_mm_prefetch((const char *) (p1 + 32), _MM_HINT_T0);
+        //_mm_prefetch((const char *) (pm0 + i + 1), _MM_HINT_T0);
+        //_mm_prefetch((const char *) (pm1 + i + 1), _MM_HINT_T0);
+        //_mm_prefetch((const char *) (pd0 + i + 1), _MM_HINT_T0);
+        //_mm_prefetch((const char *) (pd1 + i + 1), _MM_HINT_T0);
+
+        __m256i v00 = _mm256_and_si256(v0, _mm256_set1_epi32(0x000000FF));
+        __m256i v01 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x0000FFFF)), 8);
+        __m256i v02 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x00FFFFFF)), 16);
+        __m256i v03 = _mm256_srli_epi32(v0, 24);
+
+        //////////////////////
+
+        //{
+        //    uint32_t vi_0 = _mm256_extract_epi32(v00, 0);
+        //    uint32_t vi_1 = _mm256_extract_epi32(v00, 1);
+        //    uint32_t vi_2 = _mm256_extract_epi32(v00, 2);
+        //    uint32_t vi_3 = _mm256_extract_epi32(v00, 3);
+        //    uint32_t vi_4 = _mm256_extract_epi32(v00, 4);
+        //    uint32_t vi_5 = _mm256_extract_epi32(v00, 5);
+        //    uint32_t vi_6 = _mm256_extract_epi32(v00, 6);
+        //    uint32_t vi_7 = _mm256_extract_epi32(v00, 7);
+        //    printf("v0: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
+        //    printf("p0: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[0], p0[4], p0[8], p0[12], p0[16], p0[20], p0[24], p0[28]);
+        //    printf("p1: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[1], p0[5], p0[9], p0[13], p0[17], p0[21], p0[25], p0[29]);
+        //    printf("p2: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[2], p0[6], p0[10], p0[14], p0[18], p0[22], p0[26], p0[30]);
+        //    printf("p3: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[3], p0[7], p0[11], p0[15], p0[19], p0[23], p0[27], p0[31]);
+        //}
+
+        // compute 32 x 4-bit values (low and high)
+        __m256i v00l = _mm256_and_si256(v00, m4b);
+        __m256i v01l = _mm256_and_si256(v01, m4b);
+        __m256i v02l = _mm256_and_si256(v02, m4b);
+        __m256i v03l = _mm256_and_si256(v03, m4b);
+
+        __m256i v00h = _mm256_srli_epi32(v00, 4);
+        __m256i v01h = _mm256_srli_epi32(v01, 4);
+        __m256i v02h = _mm256_srli_epi32(v02, 4);
+        __m256i v03h = _mm256_srli_epi32(v03, 4);
+
+        //{
+        //    uint32_t vi_0 = _mm256_extract_epi32(v00l, 0);
+        //    uint32_t vi_1 = _mm256_extract_epi32(v00l, 1);
+        //    uint32_t vi_2 = _mm256_extract_epi32(v00l, 2);
+        //    uint32_t vi_3 = _mm256_extract_epi32(v00l, 3);
+        //    uint32_t vi_4 = _mm256_extract_epi32(v00l, 4);
+        //    uint32_t vi_5 = _mm256_extract_epi32(v00l, 5);
+        //    uint32_t vi_6 = _mm256_extract_epi32(v00l, 6);
+        //    uint32_t vi_7 = _mm256_extract_epi32(v00l, 7);
+
+        //    printf("v0l: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
+
+        //    vi_0 = _mm256_extract_epi32(v00h, 0);
+        //    vi_1 = _mm256_extract_epi32(v00h, 1);
+        //    vi_2 = _mm256_extract_epi32(v00h, 2);
+        //    vi_3 = _mm256_extract_epi32(v00h, 3);
+        //    vi_4 = _mm256_extract_epi32(v00h, 4);
+        //    vi_5 = _mm256_extract_epi32(v00h, 5);
+        //    vi_6 = _mm256_extract_epi32(v00h, 6);
+        //    vi_7 = _mm256_extract_epi32(v00h, 7);
+
+        //    printf("v0h: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
+        //}
+
+        // convert to float
+        __m256 vf00l = _mm256_cvtepi32_ps(v00l);
+        __m256 vf01l = _mm256_cvtepi32_ps(v01l);
+        __m256 vf02l = _mm256_cvtepi32_ps(v02l);
+        __m256 vf03l = _mm256_cvtepi32_ps(v03l);
+
+        __m256 vf00h = _mm256_cvtepi32_ps(v00h);
+        __m256 vf01h = _mm256_cvtepi32_ps(v01h);
+        __m256 vf02h = _mm256_cvtepi32_ps(v02h);
+        __m256 vf03h = _mm256_cvtepi32_ps(v03h);
+
+        //{
+        //    printf("vf00l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf00l[0], vf00l[1], vf00l[2], vf00l[3], vf00l[4], vf00l[5], vf00l[6], vf00l[7]);
+        //    printf("vf01l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf01l[0], vf01l[1], vf01l[2], vf01l[3], vf01l[4], vf01l[5], vf01l[6], vf01l[7]);
+        //    printf("vf02l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf02l[0], vf02l[1], vf02l[2], vf02l[3], vf02l[4], vf02l[5], vf02l[6], vf02l[7]);
+        //    printf("vf03l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf03l[0], vf03l[1], vf03l[2], vf03l[3], vf03l[4], vf03l[5], vf03l[6], vf03l[7]);
+        //}
+
+        // multiply by scale and add offset
+        vf00l = _mm256_fmadd_ps(vf00l, d0v, m0v);
+        vf01l = _mm256_fmadd_ps(vf01l, d0v, m0v);
+        vf02l = _mm256_fmadd_ps(vf02l, d0v, m0v);
+        vf03l = _mm256_fmadd_ps(vf03l, d0v, m0v);
+
+        vf00h = _mm256_fmadd_ps(vf00h, d0v, m0v);
+        vf01h = _mm256_fmadd_ps(vf01h, d0v, m0v);
+        vf02h = _mm256_fmadd_ps(vf02h, d0v, m0v);
+        vf03h = _mm256_fmadd_ps(vf03h, d0v, m0v);
+
+        __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
+
+        __m256i v10 = _mm256_and_si256(v1, _mm256_set1_epi32(0x000000FF));
+        __m256i v11 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x0000FFFF)), 8);
+        __m256i v12 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x00FFFFFF)), 16);
+        __m256i v13 = _mm256_srli_epi32(v1, 24);
+
+        __m256i v10l = _mm256_and_si256(v10, m4b);
+        __m256i v11l = _mm256_and_si256(v11, m4b);
+        __m256i v12l = _mm256_and_si256(v12, m4b);
+        __m256i v13l = _mm256_and_si256(v13, m4b);
+
+        __m256i v10h = _mm256_srli_epi32(v10, 4);
+        __m256i v11h = _mm256_srli_epi32(v11, 4);
+        __m256i v12h = _mm256_srli_epi32(v12, 4);
+        __m256i v13h = _mm256_srli_epi32(v13, 4);
+
+        __m256 vf10l = _mm256_cvtepi32_ps(v10l);
+        __m256 vf11l = _mm256_cvtepi32_ps(v11l);
+        __m256 vf12l = _mm256_cvtepi32_ps(v12l);
+        __m256 vf13l = _mm256_cvtepi32_ps(v13l);
+
+        __m256 vf10h = _mm256_cvtepi32_ps(v10h);
+        __m256 vf11h = _mm256_cvtepi32_ps(v11h);
+        __m256 vf12h = _mm256_cvtepi32_ps(v12h);
+        __m256 vf13h = _mm256_cvtepi32_ps(v13h);
+
+        vf10l = _mm256_fmadd_ps(vf10l, d1v, m1v);
+        vf11l = _mm256_fmadd_ps(vf11l, d1v, m1v);
+        vf12l = _mm256_fmadd_ps(vf12l, d1v, m1v);
+        vf13l = _mm256_fmadd_ps(vf13l, d1v, m1v);
+
+        vf10h = _mm256_fmadd_ps(vf10h, d1v, m1v);
+        vf11h = _mm256_fmadd_ps(vf11h, d1v, m1v);
+        vf12h = _mm256_fmadd_ps(vf12h, d1v, m1v);
+        vf13h = _mm256_fmadd_ps(vf13h, d1v, m1v);
+
+        // compute dot product
+        sumv0 = _mm256_fmadd_ps(vf00l, vf10l, sumv0);
+        sumv0 = _mm256_fmadd_ps(vf01l, vf11l, sumv0);
+        sumv0 = _mm256_fmadd_ps(vf02l, vf12l, sumv0);
+        sumv0 = _mm256_fmadd_ps(vf03l, vf13l, sumv0);
+
+        sumv1 = _mm256_fmadd_ps(vf00h, vf10h, sumv1);
+        sumv1 = _mm256_fmadd_ps(vf01h, vf11h, sumv1);
+        sumv1 = _mm256_fmadd_ps(vf02h, vf12h, sumv1);
+        sumv1 = _mm256_fmadd_ps(vf03h, vf13h, sumv1);
+    }
+
+    // accumulate (horizontal sum)
+    const __m256 vdot = _mm256_add_ps(sumv0, sumv1);
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(vdot), _mm256_extractf128_ps(vdot, 1));
+    const __m128 t1 = _mm_hadd_ps(t0, t0);
+
+    sumf += _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+#elif QK == 64 && 0
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    const __m256i m4b = _mm256_set1_epi8(0xf);
+
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        // 64 x 4
+        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
+        const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
+
+        // 32 x 8
+        const __m256i v0l = _mm256_and_si256(v0, m4b);
+        const __m256i v1l = _mm256_and_si256(v1, m4b);
+
+        const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
+        const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
+
+        const __m256i pl = _mm256_maddubs_epi16(v0l, v1l);
+        const __m256i ph = _mm256_maddubs_epi16(v0h, v1h);
+
+        const __m256i p16 = _mm256_add_epi16(ph, pl);
+        const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
+
+        sum00 += m0*m1;
+        sum01 += m1*d0*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v0l, v0h)));
+        sum10 += m0*d1*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v1l, v1h)));
+        sum11 += d0*d1*(_mm256_hadd_epi32_gg(p));
+    }
+
+    sumf = 64.0*sum00 + sum01 + sum10 + sum11;
+#elif QK == 64 && 1 // this is the best when using min + d
+    float sum00 = 0.0f;
+
+    __m256 sum01 = _mm256_setzero_ps();
+    __m256 sum10 = _mm256_setzero_ps();
+    __m256 sum11 = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const __m256 m0v = _mm256_set1_ps(m0);
+        const __m256 d0v = _mm256_set1_ps(d0);
+
+        const __m256 m1v = _mm256_set1_ps(m1);
+        const __m256 d1v = _mm256_set1_ps(d1);
+
+        const __m256 m1d0v = _mm256_mul_ps(m1v, d0v);
+        const __m256 m0d1v = _mm256_mul_ps(m0v, d1v);
+        const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
+
+        const __m256i m4b = _mm256_set1_epi8(0xf);
+
+        // 64 x 4
+        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
+        const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
+
+        // 32 x 8
+        const __m256i v0l = _mm256_and_si256(v0, m4b);
+        const __m256i v1l = _mm256_and_si256(v1, m4b);
+
+        const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
+        const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
+
+        const __m256i v0a = _mm256_add_epi8(v0l, v0h);
+        const __m256i v1a = _mm256_add_epi8(v1l, v1h);
+
+        const __m128i v0al = _mm256_extracti128_si256(v0a, 0);
+        const __m128i v0ah = _mm256_extracti128_si256(v0a, 1);
+
+        const __m128i v1al = _mm256_extracti128_si256(v1a, 0);
+        const __m128i v1ah = _mm256_extracti128_si256(v1a, 1);
+
+        const __m128i v0as = _mm_add_epi8(v0al, v0ah);
+        const __m128i v1as = _mm_add_epi8(v1al, v1ah);
+
+        const __m256i v0as_0 = _mm256_cvtepu8_epi32(v0as);
+        const __m256i v0as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v0as, 8));
+
+        const __m256i v1as_0 = _mm256_cvtepu8_epi32(v1as);
+        const __m256i v1as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v1as, 8));
+
+        const __m256i v0ass = _mm256_add_epi32(v0as_0, v0as_1);
+        const __m256i v1ass = _mm256_add_epi32(v1as_0, v1as_1);
+
+        const __m256 v0f = _mm256_cvtepi32_ps(v0ass);
+        const __m256 v1f = _mm256_cvtepi32_ps(v1ass);
+
+        const __m256i pl = _mm256_maddubs_epi16(v0l, v1l);
+        const __m256i ph = _mm256_maddubs_epi16(v0h, v1h);
+
+        const __m256i p16 = _mm256_add_epi16(ph, pl);
+        const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
+
+        sum00 += m0*m1;
+        sum01 = _mm256_fmadd_ps(m1d0v, v0f, sum01);
+        sum10 = _mm256_fmadd_ps(m0d1v, v1f, sum10);
+        sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
+    }
+
+    sumf = 64.0*sum00 + _mm256_hadd_ps_gg(sum01) + _mm256_hadd_ps_gg(sum10) + _mm256_hadd_ps_gg(sum11);
+#endif
+#elif defined (__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+
+        // and with 0xf
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+        // dot product into uint16x8_t
+        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
+        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
+        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
+        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+
+        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
+        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
+        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
+
+        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t pl1 = vaddq_u16(pl1l, pl1h);
+        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+        const uint16x8_t ph1 = vaddq_u16(ph1l, ph1h);
+
+        const uint16x8_t pl = vaddq_u16(pl0, pl1);
+        const uint16x8_t ph = vaddq_u16(ph0, ph1);
+
+        sum00 += m0*m1;
+        sum01 += m1*d0*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h) + vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+        sum10 += m0*d1*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h) + vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+        //sum11 += d0*d1*(
+        //        vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
+        //        vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
+        sum11 += d0*d1*vaddvq_u16(vaddq_u16(pl, ph));
+    }
+
+    sumf = 64.0*sum00 + sum01 + sum10 + sum11;
+#endif
+#endif
+
+    *s = sumf;
+}
+
+// use vec_dot_gq_4 to compute the dot product of two rows
+void mul_mat_gq_4(
+    const void * src0,
+    const void * src1, // transposed
+         float * dst,
+    int m, int n, int k) {
+    assert(k % QK == 0);
+
+    const int nb = quantize_4_blocks_per_row(k);
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            vec_dot_gq_4(k, dst + ir1, src0, src1);
+            src1 = (const char *) src1 + quantize_4_row_size(k);
+        }
+        src0 = (const char *) src0 +   quantize_4_row_size(k);
+        src1 = (const char *) src1 - n*quantize_4_row_size(k);
+
+        dst = (float *) dst + n;
+    }
+}
+
+//
+// method 5
+// 4-bit quantization (without min, only delta)
+//
+
+static inline int quantize_5_blocks_per_row(int k) {
+    return k/QK;
+}
+
+static inline int quantize_5_row_size(int k) {
+    const int nb = quantize_5_blocks_per_row(k);
+
+    return nb*(sizeof(gq_scale_t) + QK/2);
+}
+
+void quantize_5_row(const float * restrict src, void * restrict dst, int k) {
+    assert(k % QK == 0);
+    assert(QB == 4);
+
+    const int nb = quantize_5_blocks_per_row(k);
+
+    gq_scale_t * restrict pd = (gq_scale_t *) (dst);
+    uint8_t    * restrict pb = (uint8_t *)    (pd + nb);
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        memset(pp, 0, sizeof(pp));
+
+        float amax = 0.0f; // absolute max
+
+#if defined(__AVX2__)
+        {
+            assert(QK == 64);
+            const int QK8 = QK/8;
+
+            __m256 srcv [QK8];
+            __m256 asrcv[QK8];
+            __m256 amaxv[QK8];
+
+            for (int l = 0; l < QK8; l++) {
+                srcv[l]  = _mm256_loadu_ps(src + i*QK + 8*l);
+            }
+
+            for (int l = 0; l < QK8; l++) {
+                asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
+            }
+
+
+            for (int l = 0; l < QK8/2; l++) {
+                amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
+            }
+
+            for (int l = 0; l < QK8/4; l++) {
+                amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
+            }
+
+            for (int l = 0; l < QK8/8; l++) {
+                amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]);
+            }
+
+            //amax = MAX(amaxv[0][0], MAX(amaxv[0][1], MAX(amaxv[0][2], MAX(amaxv[0][3], MAX(amaxv[0][4], MAX(amaxv[0][5], MAX(amaxv[0][6], amaxv[0][7])))))));
+
+            const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
+            const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
+            const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
+            const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
+            const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
+            const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
+
+            amax = _mm256_cvtss_f32(amaxv0_5);
+
+            //printf("amax = %f\n", amax);
+
+            const float d = amax / ((1 << (QB - 1)) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            const __m256 idv = _mm256_set1_ps(id);
+
+            for (int l = 0; l < QK/8; l++) {
+                __m256 v = _mm256_mul_ps(srcv[l], idv);
+#if 0
+                v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
+                v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
+#endif
+
+                // convert to int8
+                __m256i vi = _mm256_cvtps_epi32(v);
+                vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
+
+                int32_t vi_0 = _mm256_extract_epi32(vi, 0);
+                int32_t vi_1 = _mm256_extract_epi32(vi, 1);
+                int32_t vi_2 = _mm256_extract_epi32(vi, 2);
+                int32_t vi_3 = _mm256_extract_epi32(vi, 3);
+
+                int32_t vi_4 = _mm256_extract_epi32(vi, 4);
+                int32_t vi_5 = _mm256_extract_epi32(vi, 5);
+                int32_t vi_6 = _mm256_extract_epi32(vi, 6);
+                int32_t vi_7 = _mm256_extract_epi32(vi, 7);
+
+                // convert to 4-bit, 2 consecutive packed into 1 byte
+                pp[4*l + 0] = vi_0 | (vi_1 << 4);
+                pp[4*l + 1] = vi_2 | (vi_3 << 4);
+                pp[4*l + 2] = vi_4 | (vi_5 << 4);
+                pp[4*l + 3] = vi_6 | (vi_7 << 4);
+
+                //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
+                ////printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
+
+                assert(vi_0 >= 0 && vi_0 < 16);
+                assert(vi_1 >= 0 && vi_1 < 16);
+                assert(vi_2 >= 0 && vi_2 < 16);
+                assert(vi_3 >= 0 && vi_3 < 16);
+
+                assert(vi_4 >= 0 && vi_4 < 16);
+                assert(vi_5 >= 0 && vi_5 < 16);
+                assert(vi_6 >= 0 && vi_6 < 16);
+                assert(vi_7 >= 0 && vi_7 < 16);
+            }
+
+            memcpy(pb + i*QK/2, pp, sizeof(pp));
+        }
+#elif defined(__ARM_NEON) && 0
+        {
+            // TODO
+        }
+#else
+        {
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l];
+                amax = MAX(amax, fabsf(v));
+            }
+
+            const float d = amax / ((1 << (QB - 1)) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l]*id;
+                const int8_t vi = ((int8_t) (round(v))) + 8;
+                assert(vi >= 0 && vi < 16);
+                pp[l/2] |= (vi & 0xf) << (4*(l & 1));
+            }
+
+            memcpy(pb + i*QK/2, pp, sizeof(pp));
+        }
+#endif
+        //printf("min %f max %f\n", min, max);
+    }
+}
+
+// reimplementation of quantize_5 using quantize_5_row
+void quantize_5(const float * restrict src, char * restrict dst, int n, int k) {
+    assert(k % QK == 0);
+
+    for (int j = 0; j < n; j++) {
+        quantize_5_row(src + j*k, dst, k);
+        dst = (char *) dst + quantize_5_row_size(k);
+    }
+}
+
+void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = quantize_5_blocks_per_row(n);
+
+    const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
+    const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
+    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+
+    float sumf = 0.0;
+
+#if 0
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+            const float f1 = d0*((int8_t) (v0 >> 4)  - 8);
+
+            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+            const float f3 = d1*((int8_t) (v1 >> 4)  - 8);
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#else
+#if defined(__AVX2__)
+#if QK == 64 && 1
+    __m256 sum11 = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const __m256 d0v = _mm256_set1_ps(d0);
+        const __m256 d1v = _mm256_set1_ps(d1);
+
+        const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
+
+        const __m256i m4b = _mm256_set1_epi8(0xf);
+
+        // 64 x 4
+        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
+        const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
+
+        // 32 x 8
+        __m256i v0l = _mm256_and_si256(v0, m4b);
+        __m256i v1l = _mm256_and_si256(v1, m4b);
+
+        __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
+        __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
+
+        // sub 8
+        v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8));
+        v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8));
+
+        v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8));
+        v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8));
+
+        // abs
+        const __m256i v0la = _mm256_sign_epi8(v0l, v0l);
+        const __m256i v0ha = _mm256_sign_epi8(v0h, v0h);
+
+        // sign
+        const __m256i v1ls = _mm256_sign_epi8(v1l, v0l);
+        const __m256i v1hs = _mm256_sign_epi8(v1h, v0h);
+
+        const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls);
+        const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs);
+
+        const __m256i p16 = _mm256_add_epi16(ph, pl);
+        const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
+
+        sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
+    }
+
+    sumf = _mm256_hadd_ps_gg(sum11);
+#endif
+#elif defined (__ARM_NEON)
+    float sum11 = 0.0f;
+
+    //float32x4_t sum_0 = vdupq_n_f32(0.0f);
+    //float32x4_t sum_1 = vdupq_n_f32(0.0f);
+
+    //float16x8_t sum_0 = vdupq_n_f16(0.0f);
+    //float16x8_t sum_1 = vdupq_n_f16(0.0f);
+
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        //float32x4_t d0d1v = vdupq_n_f32(d0*d1);
+        //float16x8_t d0d1v = vdupq_n_f16(d0*d1);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
+        const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t pl = vaddq_s16(pl0, pl1);
+        const int16x8_t ph = vaddq_s16(ph0, ph1);
+
+        //const int8x16_t pl0 = vmulq_s8(v0_0ls, v1_0ls);
+        //const int8x16_t pl1 = vmulq_s8(v0_1ls, v1_1ls);
+        //const int8x16_t ph0 = vmulq_s8(v0_0hs, v1_0hs);
+        //const int8x16_t ph1 = vmulq_s8(v0_1hs, v1_1hs);
+
+        //const int16x8_t pll = vaddl_s8(vget_low_s8(pl0),  vget_low_s8(pl1));
+        //const int16x8_t plh = vaddl_s8(vget_high_s8(pl0), vget_high_s8(pl1));
+        //const int16x8_t phl = vaddl_s8(vget_low_s8(ph0),  vget_low_s8(ph1));
+        //const int16x8_t phh = vaddl_s8(vget_high_s8(ph0), vget_high_s8(ph1));
+
+        //const int16x8_t pl = vaddq_s16(pll, plh);
+        //const int16x8_t ph = vaddq_s16(phl, phh);
+
+        const int16x8_t p = vaddq_s16(pl, ph);
+
+        // convert to float
+        //const float32x4_t pf0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (p)));
+        //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));
+
+        // scalar
+        sum11 += d0*d1*vaddvq_s16(p);
+        //sum11 += d0*d1*(vaddvq_s16(pl) + vaddvq_s16(ph));
+        //sum11 += d0*d1*vaddvq_s16(vaddq_s16(pl, ph));
+        //sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
+        //sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));
+
+        //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(p));
+        //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(pl));
+        //sum_1 = vfmaq_f16(sum_1, d0d1v, vcvtq_f16_s16(ph));
+
+        // vectorize
+        //sum_0 = vmlaq_f32(sum_0, d0d1v, pf0);
+        //sum_1 = vmlaq_f32(sum_1, d0d1v, pf1);
+    }
+
+    sumf = sum11;
+    //sumf = vaddvq_f32(sum_0) + vaddvq_f32(sum_1);
+    //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
+    //sum_0 = vaddq_f16(sum_0, sum_1);
+    //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
+#endif
+#endif
+
+    *s = sumf;
+}
+
+// use vec_dot_gq_5 to compute the dot product of two rows
+void mul_mat_gq_5(
+    const void * src0,
+    const void * src1, // transposed
+         float * dst,
+    int m, int n, int k) {
+    assert(k % QK == 0);
+
+    const int nb = quantize_5_blocks_per_row(k);
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            vec_dot_gq_5(k, dst + ir1, src0, src1);
+            src1 = (const char *) src1 + quantize_5_row_size(k);
+        }
+        src0 = (const char *) src0 +   quantize_5_row_size(k);
+        src1 = (const char *) src1 - n*quantize_5_row_size(k);
+
+        dst = (float *) dst + n;
+    }
+}
+
+//
+// method 6
+// same as 5 but with 32 element blocks
+//
+
+static inline int quantize_6_blocks_per_row(int k) {
+    return k/32;
+}
+
+static inline int quantize_6_row_size(int k) {
+    const int nb = quantize_6_blocks_per_row(k);
+
+    return nb*(sizeof(gq_scale_t) + 16);
+}
+
+void quantize_6_row(const float * restrict src, void * restrict dst, int k) {
+    assert(k % 32 == 0);
+    assert(QB == 4);
+
+    const int nb = quantize_6_blocks_per_row(k);
+
+    gq_scale_t * restrict pd = (gq_scale_t *) (dst);
+    uint8_t    * restrict pb = (uint8_t *)    (pd + nb);
+
+    uint8_t pp[16];
+
+    for (int i = 0; i < nb; i++) {
+        memset(pp, 0, sizeof(pp));
+
+        float amax = 0.0f; // absolute max
+
+#if defined(__AVX2__)
+        {
+            const int QK8 = 4;
+
+            __m256 srcv [QK8];
+            __m256 asrcv[QK8];
+            __m256 amaxv[QK8];
+
+            for (int l = 0; l < QK8; l++) {
+                srcv[l]  = _mm256_loadu_ps(src + i*32 + 8*l);
+            }
+
+            for (int l = 0; l < QK8; l++) {
+                asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
+            }
+
+            for (int l = 0; l < QK8/2; l++) {
+                amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
+            }
+
+            for (int l = 0; l < QK8/4; l++) {
+                amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
+            }
+
+            const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
+            const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
+            const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
+            const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
+            const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
+            const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
+
+            amax = _mm256_cvtss_f32(amaxv0_5);
+
+            const float d = amax / ((1 << (QB - 1)) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            const __m256 idv = _mm256_set1_ps(id);
+
+            for (int l = 0; l < 4; l++) {
+                __m256 v = _mm256_mul_ps(srcv[l], idv);
+
+                // convert to int8
+                __m256i vi = _mm256_cvtps_epi32(v);
+                vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
+
+                int32_t vi_0 = _mm256_extract_epi32(vi, 0);
+                int32_t vi_1 = _mm256_extract_epi32(vi, 1);
+                int32_t vi_2 = _mm256_extract_epi32(vi, 2);
+                int32_t vi_3 = _mm256_extract_epi32(vi, 3);
+
+                int32_t vi_4 = _mm256_extract_epi32(vi, 4);
+                int32_t vi_5 = _mm256_extract_epi32(vi, 5);
+                int32_t vi_6 = _mm256_extract_epi32(vi, 6);
+                int32_t vi_7 = _mm256_extract_epi32(vi, 7);
+
+                // convert to 4-bit, 2 consecutive packed into 1 byte
+                pp[4*l + 0] = vi_0 | (vi_1 << 4);
+                pp[4*l + 1] = vi_2 | (vi_3 << 4);
+                pp[4*l + 2] = vi_4 | (vi_5 << 4);
+                pp[4*l + 3] = vi_6 | (vi_7 << 4);
+
+                assert(vi_0 >= 0 && vi_0 < 16);
+                assert(vi_1 >= 0 && vi_1 < 16);
+                assert(vi_2 >= 0 && vi_2 < 16);
+                assert(vi_3 >= 0 && vi_3 < 16);
+
+                assert(vi_4 >= 0 && vi_4 < 16);
+                assert(vi_5 >= 0 && vi_5 < 16);
+                assert(vi_6 >= 0 && vi_6 < 16);
+                assert(vi_7 >= 0 && vi_7 < 16);
+            }
+
+            memcpy(pb + i*16, pp, sizeof(pp));
+        }
+#elif defined(__ARM_NEON)
+        {
+            float32x4_t srcv [8];
+            float32x4_t asrcv[8];
+            float32x4_t amaxv[8];
+
+            for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(src + i*32 + 4*l);
+            for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+            for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+            for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+            for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+            amax = MAX(
+                    MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
+                    MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+
+            const float d = amax / ((1 << 3) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            for (int l = 0; l < 8; l++) {
+                const float32x4_t v = vmulq_n_f32(srcv[l], id);
+                const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
+                const int32x4_t vi = vcvtq_s32_f32(vf);
+
+                pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+                pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+            }
+
+            memcpy(pb + i*16, pp, sizeof(pp));
+        }
+#else
+        {
+            for (int l = 0; l < 32; l++) {
+                const float v = src[i*32 + l];
+                amax = MAX(amax, fabsf(v));
+            }
+
+            const float d = amax / ((1 << (QB - 1)) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            for (int l = 0; l < 32; l++) {
+                const float v = src[i*32 + l]*id;
+                const int8_t vi = ((int8_t) (round(v))) + 8;
+                assert(vi >= 0 && vi < 16);
+                pp[l/2] |= (vi & 0xf) << (4*(l & 1));
+            }
+
+            memcpy(pb + i*16, pp, sizeof(pp));
+        }
+#endif
+        //printf("amax = %f\n", amax);
+    }
+}
+
+// reimplementation of quantize__6using quantize_6_row
+void quantize_6(const float * restrict src, char * restrict dst, int n, int k) {
+    assert(k % 32 == 0);
+
+    for (int j = 0; j < n; j++) {
+        quantize_6_row(src + j*k, dst, k);
+        dst = (char *) dst + quantize_6_row_size(k);
+    }
+}
+
+void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = quantize_6_blocks_per_row(n);
+
+    const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
+    const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
+    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+
+    float sumf = 0.0;
+
+#if 0
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        for (int j = 0; j < 16; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+            const float f1 = d0*((int8_t) (v0 >> 4)  - 8);
+
+            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+            const float f3 = d1*((int8_t) (v1 >> 4)  - 8);
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#else
+#if defined(__AVX2__)
+    // TODO
+#elif defined (__ARM_NEON)
+#if 0
+    float sum0 = 0.0f;
+
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        //float32x4_t d0d1v = vdupq_n_f32(d0*d1);
+        //float16x8_t d0d1v = vdupq_n_f16(d0*d1);
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+
+        // 4-bit -> 8-bit
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t p = vaddq_s16(pl, ph);
+
+        // scalar
+        sum0 += d0*d1*vaddvq_s16(p);
+    }
+
+    sumf = sum0;
+#elif 1 // this is a bit faster than the above
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = GGML_GQ_TO_FP32(pd0[i + 0]);
+        const float d1_0 = GGML_GQ_TO_FP32(pd1[i + 0]);
+        const float d0_1 = GGML_GQ_TO_FP32(pd0[i + 1]);
+        const float d1_1 = GGML_GQ_TO_FP32(pd1[i + 1]);
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+
+        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+
+        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        // scalar
+        sum0 += d0_0*d1_0*vaddvq_s16(p_0);
+        sum1 += d0_1*d1_1*vaddvq_s16(p_1);
+    }
+
+    sumf = sum0 + sum1;
+#endif
+#endif
+#endif
+
+    *s = sumf;
+}
+
+// use vec_dot_gq_6 to compute the dot product of two rows
+void mul_mat_gq_6(
+    const void * src0,
+    const void * src1, // transposed
+         float * dst,
+    int m, int n, int k) {
+    assert(k % 32 == 0);
+
+    const int nb = quantize_6_blocks_per_row(k);
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            vec_dot_gq_6(k, dst + ir1, src0, src1);
+            src1 = (const char *) src1 + quantize_6_row_size(k);
+        }
+        src0 = (const char *) src0 +   quantize_6_row_size(k);
+        src1 = (const char *) src1 - n*quantize_6_row_size(k);
+
+        dst = (float *) dst + n;
+    }
+}
+
+int main(int argc, const char ** argv) {
+    assert(sizeof(gq_quant_t)*8 == gq_t_bits);
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    int method = 0;
+    if (argc > 1) {
+        method = atoi(argv[1]);
+    }
+
+    float * src0 = (float *)malloc(sizeof(float)*M*K);
+    float * src1 = (float *)malloc(sizeof(float)*N*K);
+    float * dst  = (float *)malloc(sizeof(float)*M*N);
+
+    // allocate aligned memory
+    //float * src0 = (float *)aligned_alloc(32, sizeof(float)*M*K);
+    //float * src1 = (float *)aligned_alloc(32, sizeof(float)*N*K);
+    //float * dst  = (float *)aligned_alloc(32, sizeof(float)*M*N);
+
+    for (int i = 0; i < M*K; i++) {
+        src0[i] = 0.8 - rand() / (float)RAND_MAX;
+        /*src0[i] = rand() / (float)RAND_MAX;*/
+        /*src0[i] = i % 2;*/
+    }
+
+    for (int i = 0; i < N*K; i++) {
+        src1[i] = 0.8 - rand() / (float)RAND_MAX;
+        /*src1[i] = rand() / (float)RAND_MAX;*/
+        /*src1[i] = i % 3;*/
+    }
+
+    void * src0_gq = NULL;
+    void * src1_gq = NULL;
+
+    size_t sizegq = 0;
+
+    {
+        if (method == 1) {
+            src0_gq = calloc(1, quantize_1_row_size(K)*M);
+            src1_gq = calloc(1, quantize_1_row_size(K)*N);
+
+            sizegq  = quantize_1_row_size(K)*M + quantize_1_row_size(K)*N;
+        }
+
+        if (method == 2) {
+            src0_gq = calloc(1, quantize_2_row_size(K)*M);
+            src1_gq = calloc(1, quantize_2_row_size(K)*N);
+
+            sizegq  = quantize_2_row_size(K)*M + quantize_2_row_size(K)*N;
+        }
+
+        if (method == 3) {
+            src0_gq = calloc(1, quantize_3_row_size(K)*M);
+            src1_gq = calloc(1, quantize_3_row_size(K)*N);
+
+            sizegq  = quantize_3_row_size(K)*M + quantize_3_row_size(K)*N;
+        }
+
+        if (method == 4) {
+            src0_gq = calloc(1, quantize_4_row_size(K)*M);
+            src1_gq = calloc(1, quantize_4_row_size(K)*N);
+
+            sizegq  = quantize_4_row_size(K)*M + quantize_4_row_size(K)*N;
+        }
+
+        if (method == 5) {
+            src0_gq = calloc(1, quantize_5_row_size(K)*M);
+            src1_gq = calloc(1, quantize_5_row_size(K)*N);
+
+            sizegq  = quantize_5_row_size(K)*M + quantize_5_row_size(K)*N;
+        }
+
+        if (method == 6) {
+            src0_gq = calloc(1, quantize_6_row_size(K)*M);
+            src1_gq = calloc(1, quantize_6_row_size(K)*N);
+
+            sizegq  = quantize_6_row_size(K)*M + quantize_6_row_size(K)*N;
+        }
+    }
+
+    const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K;
+
+    printf("compression: %f\n", (float)sizegq/sizef16);
+
+    // convert fp32 -> gq
+    {
+        const uint64_t t_start = get_time_us();
+
+        if (method == 1) {
+            quantize_1(src0, src0_gq, M, K);
+            quantize_1(src1, src1_gq, N, K);
+        }
+
+        if (method == 2) {
+            quantize_2(src0, src0_gq, M, K);
+            quantize_2(src1, src1_gq, N, K);
+        }
+
+        if (method == 3) {
+            quantize_3(src0, src0_gq, M, K);
+            quantize_3(src1, src1_gq, N, K);
+        }
+
+        if (method == 4) {
+            quantize_4(src0, src0_gq, M, K);
+            quantize_4(src1, src1_gq, N, K);
+        }
+
+        if (method == 5) {
+            quantize_5(src0, src0_gq, M, K);
+            quantize_5(src1, src1_gq, N, K);
+        }
+
+        if (method == 6) {
+            quantize_6(src0, src0_gq, M, K);
+            quantize_6(src1, src1_gq, N, K);
+        }
+
+        const uint64_t t_end = get_time_us();
+        printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
+    }
+
+    for (int i = 0; i < 16; ++i) {
+        printf("%f %f\n", src0[i], src1[i]);
+    }
+
+    const int nIter = 1;
+
+    const clock_t start = clock();
+    const uint64_t start_us = get_time_us();
+
+    double iM = 1.0/M;
+    double sum = 0.0f;
+    for (int i = 0; i < nIter; i++) {
+        if (method == 0) {
+            mul_mat_f32_naive(src0, src1, dst, M, N, K);
+        }
+
+        if (method == 1) {
+            mul_mat_gq_1(src0_gq, src1_gq, dst, M, N, K);
+        }
+
+        if (method == 2) {
+            mul_mat_gq_2(src0_gq, src1_gq, dst, M, N, K);
+        }
+
+        if (method == 3) {
+            mul_mat_gq_3(src0_gq, src1_gq, dst, M, N, K);
+        }
+
+        if (method == 4) {
+            mul_mat_gq_4(src0_gq, src1_gq, dst, M, N, K);
+        }
+
+        if (method == 5) {
+            mul_mat_gq_5(src0_gq, src1_gq, dst, M, N, K);
+        }
+
+        if (method == 6) {
+            mul_mat_gq_6(src0_gq, src1_gq, dst, M, N, K);
+        }
+    }
+
+    for (int i = 0; i < N; i++) {
+        sum += dst[i]*iM;
+    }
+
+    {
+        const clock_t end = clock();
+        const uint64_t end_us = get_time_us();
+        printf("%s: elapsed ticks: %ld\n",  __func__, end - start);
+        printf("%s: elapsed us:    %d / %f ms\n",  __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter);
+    }
+
+#if 0
+    // print src0
+    printf("src0:\n");
+    for (int i = 0; i < M; i++) {
+        for (int j = 0; j < K; j++) {
+            printf("%4.1f ", src0[i*K+j]);
+        }
+        printf("\n");
+    }
+
+    // print src1
+    printf("src1:\n");
+    for (int i = 0; i < N; i++) {
+        for (int j = 0; j < K; j++) {
+            printf("%4.1f ", src1[i*K+j]);
+        }
+        printf("\n");
+    }
+
+    printf("dst:\n");
+    for (int i = 0; i < M; i++) {
+        for (int j = 0; j < N; j++) {
+            printf("%4.1f ", dst[i*N+j]);
+        }
+        printf("\n");
+    }
+#endif
 
     printf("%f\n", sum);
 
@@ -468,8 +2584,8 @@ int main(int argc, const char ** argv) {
     free(src1);
     free(dst);
 
-    free(src0_gq);
-    free(src1_gq);
+    if (src0_gq) free(src0_gq);
+    if (src1_gq) free(src1_gq);
 
     return 0;
 }