From: Georgi Gerganov Date: Thu, 27 Apr 2023 15:31:53 +0000 (+0300) Subject: ggml : sync llama.cpp (Q5_0 + Q5_1) + refactor examples quantization X-Git-Tag: upstream/0.0.1642~1517 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=abea4b7609c14b837015ab625e3ac36c4708dd03;p=pkg%2Fggml%2Fsources%2Fggml ggml : sync llama.cpp (Q5_0 + Q5_1) + refactor examples quantization --- diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index 60aaedaf..6deee198 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -1,26 +1,86 @@ #include "common-ggml.h" -#include "ggml.h" - #include +static const std::map GGML_FTYPE_MAP = { + {"q4_0", GGML_FTYPE_MOSTLY_Q4_0}, + {"q4_1", GGML_FTYPE_MOSTLY_Q4_1}, + {"q4_2", GGML_FTYPE_MOSTLY_Q4_2}, + {"q4_3", GGML_FTYPE_MOSTLY_Q4_3}, + {"q5_0", GGML_FTYPE_MOSTLY_Q5_0}, + {"q5_1", GGML_FTYPE_MOSTLY_Q5_1}, + {"q8_0", GGML_FTYPE_MOSTLY_Q8_0}, +}; + +void ggml_print_ftypes(FILE * fp) { + for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) { + fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second); + } +} + +enum ggml_ftype ggml_parse_ftype(const char * str) { + enum ggml_ftype ftype; + if (str[0] == 'q') { + const auto it = GGML_FTYPE_MAP.find(str); + if (it == GGML_FTYPE_MAP.end()) { + fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, str); + return GGML_FTYPE_UNKNOWN; + } + ftype = it->second; + } else { + ftype = (enum ggml_ftype) atoi(str); + } + + return ftype; +} + +enum ggml_type ggml_ftype_to_ggml_type(const enum ggml_ftype ftype) { + ggml_type wtype = GGML_TYPE_COUNT; + + switch (ftype) { + case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; + case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q4_2: wtype = GGML_TYPE_Q4_2; break; + case GGML_FTYPE_MOSTLY_Q4_3: wtype = GGML_TYPE_Q4_3; break; + case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; + } + + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); + } + + return wtype; +} + bool ggml_common_quantize_0( std::ifstream & finp, std::ofstream & fout, - const ggml_mtype mtype, + const ggml_ftype ftype, const std::vector & to_quant, const std::vector & to_skip) { ggml_type qtype = GGML_TYPE_F32; - switch (mtype) { - case 2: qtype = GGML_TYPE_Q4_0; break; - case 3: qtype = GGML_TYPE_Q4_1; break; - case 5: qtype = GGML_TYPE_Q4_2; break; - case 6: qtype = GGML_TYPE_Q4_3; break; - default: + switch (ftype) { + case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q4_2: qtype = GGML_TYPE_Q4_2; break; + case GGML_FTYPE_MOSTLY_Q4_3: qtype = GGML_TYPE_Q4_3; break; + case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_UNKNOWN: + case GGML_FTYPE_ALL_F32: + case GGML_FTYPE_MOSTLY_F16: + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: { - fprintf(stderr, "%s: invalid model type %d\n", __func__, mtype); + fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; } }; @@ -127,7 +187,7 @@ bool ggml_common_quantize_0( size_t cur_size = 0; std::vector hist_cur(1 << 4, 0); - switch (ttype) { + switch ((ggml_type) ttype) { case GGML_TYPE_Q4_0: { cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); @@ -144,7 +204,25 @@ bool ggml_common_quantize_0( { cur_size = ggml_quantize_q4_3(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); } break; - default: + case GGML_TYPE_Q5_0: + { + cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q5_1: + { + cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q8_0: + { + cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_Q8_1: + case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); return false; @@ -173,7 +251,7 @@ bool ggml_common_quantize_0( } printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); - printf("%s: quant size = %8.2f MB | mtype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, mtype, ggml_type_name(qtype)); + printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype)); { int64_t sum_all = 0; diff --git a/examples/common-ggml.h b/examples/common-ggml.h index da839bf9..377a7fdb 100644 --- a/examples/common-ggml.h +++ b/examples/common-ggml.h @@ -1,23 +1,37 @@ #pragma once +#include "ggml.h" + +#include #include #include #include // model file types -enum ggml_mtype { - GGML_MTYPE_ALL_F32 = 0, - GGML_MTYPE_MOSTLY_F16 = 1, // except 1d tensors - GGML_MTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - GGML_MTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - GGML_MTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - GGML_MTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors - GGML_MTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors +enum ggml_ftype { + GGML_FTYPE_UNKNOWN = -1, + GGML_FTYPE_ALL_F32 = 0, + GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + GGML_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors }; +void ggml_print_ftypes(FILE * fp = stderr); + +enum ggml_ftype ggml_parse_ftype(const char * str); + +// TODO: temporary +enum ggml_type ggml_ftype_to_ggml_type(const enum ggml_ftype ftype); + bool ggml_common_quantize_0( std::ifstream & finp, std::ofstream & fout, - const ggml_mtype mtype, + const ggml_ftype ftype, const std::vector & to_quant, const std::vector & to_skip); diff --git a/examples/gpt-2/CMakeLists.txt b/examples/gpt-2/CMakeLists.txt index c1917c13..1d9bcdd8 100644 --- a/examples/gpt-2/CMakeLists.txt +++ b/examples/gpt-2/CMakeLists.txt @@ -3,7 +3,7 @@ set(TEST_TARGET gpt-2) add_executable(${TEST_TARGET} main.cpp) -target_link_libraries(${TEST_TARGET} PRIVATE ggml common) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) # # gpt-2-quantize diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp index 519f901a..f90ee67c 100644 --- a/examples/gpt-2/main.cpp +++ b/examples/gpt-2/main.cpp @@ -1,6 +1,7 @@ #include "ggml/ggml.h" #include "common.h" +#include "common-ggml.h" #include #include @@ -20,7 +21,7 @@ struct gpt2_hparams { int32_t n_embd = 768; int32_t n_head = 12; int32_t n_layer = 12; - int32_t f16 = 1; + int32_t ftype = 1; }; struct gpt2_layer { @@ -97,14 +98,14 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); - fin.read((char *) &hparams.f16, sizeof(hparams.f16)); + fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); printf("%s: n_embd = %d\n", __func__, hparams.n_embd); printf("%s: n_head = %d\n", __func__, hparams.n_head); printf("%s: n_layer = %d\n", __func__, hparams.n_layer); - printf("%s: f16 = %d\n", __func__, hparams.f16); + printf("%s: ftype = %d\n", __func__, hparams.ftype); } // load vocab @@ -133,24 +134,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - ggml_type wtype = GGML_TYPE_COUNT; - switch (model.hparams.f16) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - case 2: wtype = GGML_TYPE_Q4_0; break; - case 3: wtype = GGML_TYPE_Q4_1; break; - case 5: wtype = GGML_TYPE_Q4_2; break; - case 6: wtype = GGML_TYPE_Q4_3; break; - default: - { - fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", - __func__, fname.c_str(), model.hparams.f16); - return false; - } + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", + __func__, fname.c_str(), model.hparams.ftype); + return false; } - const ggml_type wtype2 = GGML_TYPE_F32; - auto & ctx = model.ctx; size_t ctx_size = 0; diff --git a/examples/gpt-2/quantize.cpp b/examples/gpt-2/quantize.cpp index f456a396..09b99ffe 100644 --- a/examples/gpt-2/quantize.cpp +++ b/examples/gpt-2/quantize.cpp @@ -24,7 +24,7 @@ struct gpt2_hparams { }; // quantize a model -bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_mtype mtype) { +bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { gpt_vocab vocab; printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); @@ -76,7 +76,7 @@ bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fnam fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd)); fout.write((char *) &hparams.n_head, sizeof(hparams.n_head)); fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer)); - fout.write((char *) &mtype, sizeof(hparams.f16)); + fout.write((char *) &ftype, sizeof(hparams.f16)); } // load vocab @@ -116,7 +116,7 @@ bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fnam "model/h.*/mlp/c_proj/w", }; - if (!ggml_common_quantize_0(finp, fout, mtype, to_quant, {})) { + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) { fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); return false; } @@ -133,10 +133,7 @@ bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fnam int main(int argc, char ** argv) { if (argc != 4) { fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); - fprintf(stderr, " type = 2 - q4_0\n"); - fprintf(stderr, " type = 3 - q4_1\n"); - fprintf(stderr, " type = 5 - q4_2\n"); - fprintf(stderr, " type = 6 - q4_3\n"); + ggml_print_ftypes(stderr); return 1; } @@ -150,7 +147,7 @@ int main(int argc, char ** argv) { const std::string fname_inp = argv[1]; const std::string fname_out = argv[2]; - const int mtype = atoi(argv[3]); + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); const int64_t t_main_start_us = ggml_time_us(); @@ -160,7 +157,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!gpt2_model_quantize(fname_inp, fname_out, ggml_mtype(mtype))) { + if (!gpt2_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/examples/gpt-j/CMakeLists.txt b/examples/gpt-j/CMakeLists.txt index c8b5ecfe..3675b7df 100644 --- a/examples/gpt-j/CMakeLists.txt +++ b/examples/gpt-j/CMakeLists.txt @@ -3,7 +3,7 @@ set(TEST_TARGET gpt-j) add_executable(${TEST_TARGET} main.cpp) -target_link_libraries(${TEST_TARGET} PRIVATE ggml common) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) # # gpt-j-quantize diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp index eeca87fd..99ed4406 100644 --- a/examples/gpt-j/main.cpp +++ b/examples/gpt-j/main.cpp @@ -1,6 +1,7 @@ #include "ggml/ggml.h" #include "common.h" +#include "common-ggml.h" #include #include @@ -21,7 +22,7 @@ struct gptj_hparams { int32_t n_head = 16; int32_t n_layer = 28; int32_t n_rot = 64; - int32_t f16 = 1; + int32_t ftype = 1; }; struct gptj_layer { @@ -97,7 +98,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot)); - fin.read((char *) &hparams.f16, sizeof(hparams.f16)); + fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); @@ -105,7 +106,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & printf("%s: n_head = %d\n", __func__, hparams.n_head); printf("%s: n_layer = %d\n", __func__, hparams.n_layer); printf("%s: n_rot = %d\n", __func__, hparams.n_rot); - printf("%s: f16 = %d\n", __func__, hparams.f16); + printf("%s: ftype = %d\n", __func__, hparams.ftype); } // load vocab @@ -134,24 +135,13 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - ggml_type wtype = GGML_TYPE_COUNT; - switch (model.hparams.f16) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - case 2: wtype = GGML_TYPE_Q4_0; break; - case 3: wtype = GGML_TYPE_Q4_1; break; - case 5: wtype = GGML_TYPE_Q4_2; break; - case 6: wtype = GGML_TYPE_Q4_3; break; - default: - { - fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", - __func__, fname.c_str(), model.hparams.f16); - return false; - } + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", + __func__, fname.c_str(), model.hparams.ftype); + return false; } - const ggml_type wtype2 = GGML_TYPE_F32; - auto & ctx = model.ctx; size_t ctx_size = 0; diff --git a/examples/gpt-j/quantize.cpp b/examples/gpt-j/quantize.cpp index 08ceabd6..3f9f4b38 100644 --- a/examples/gpt-j/quantize.cpp +++ b/examples/gpt-j/quantize.cpp @@ -25,7 +25,7 @@ struct gptj_hparams { }; // quantize a model -bool gptj_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_mtype mtype) { +bool gptj_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { gpt_vocab vocab; printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); @@ -79,7 +79,7 @@ bool gptj_model_quantize(const std::string & fname_inp, const std::string & fnam fout.write((char *) &hparams.n_head, sizeof(hparams.n_head)); fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer)); fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot)); - fout.write((char *) &mtype, sizeof(hparams.f16)); + fout.write((char *) &ftype, sizeof(hparams.f16)); } // load vocab @@ -114,7 +114,7 @@ bool gptj_model_quantize(const std::string & fname_inp, const std::string & fnam ".*weight", }; - if (!ggml_common_quantize_0(finp, fout, mtype, to_quant, {})) { + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) { fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); return false; } @@ -131,10 +131,7 @@ bool gptj_model_quantize(const std::string & fname_inp, const std::string & fnam int main(int argc, char ** argv) { if (argc != 4) { fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); - fprintf(stderr, " type = 2 - q4_0\n"); - fprintf(stderr, " type = 3 - q4_1\n"); - fprintf(stderr, " type = 5 - q4_2\n"); - fprintf(stderr, " type = 6 - q4_3\n"); + ggml_print_ftypes(stderr); return 1; } @@ -148,7 +145,7 @@ int main(int argc, char ** argv) { const std::string fname_inp = argv[1]; const std::string fname_out = argv[2]; - const int mtype = atoi(argv[3]); + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); const int64_t t_main_start_us = ggml_time_us(); @@ -158,7 +155,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!gptj_model_quantize(fname_inp, fname_out, ggml_mtype(mtype))) { + if (!gptj_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/examples/stablelm/CMakeLists.txt b/examples/stablelm/CMakeLists.txt index 58c7d740..8066610f 100644 --- a/examples/stablelm/CMakeLists.txt +++ b/examples/stablelm/CMakeLists.txt @@ -3,7 +3,7 @@ set(TEST_TARGET stablelm) add_executable(${TEST_TARGET} main.cpp) -target_link_libraries(${TEST_TARGET} PRIVATE ggml common) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) # # stablelm-quantize diff --git a/examples/stablelm/main.cpp b/examples/stablelm/main.cpp index cd7918be..3cf6b1cd 100644 --- a/examples/stablelm/main.cpp +++ b/examples/stablelm/main.cpp @@ -1,6 +1,7 @@ #include "ggml/ggml.h" #include "common.h" +#include "common-ggml.h" #include #include @@ -131,24 +132,13 @@ bool stablelm_model_load(const std::string & fname, stablelm_model & model, gpt_ // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - ggml_type wtype = GGML_TYPE_COUNT; - switch (model.hparams.ftype) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - case 2: wtype = GGML_TYPE_Q4_0; break; - case 3: wtype = GGML_TYPE_Q4_1; break; - case 5: wtype = GGML_TYPE_Q4_2; break; - case 6: wtype = GGML_TYPE_Q4_3; break; - default: - { - fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", - __func__, fname.c_str(), model.hparams.ftype); - return false; - } + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", + __func__, fname.c_str(), model.hparams.ftype); + return false; } - const ggml_type wtype2 = GGML_TYPE_F32; - auto & ctx = model.ctx; size_t ctx_size = 0; diff --git a/examples/stablelm/quantize.cpp b/examples/stablelm/quantize.cpp index 90a4799d..6df1a061 100644 --- a/examples/stablelm/quantize.cpp +++ b/examples/stablelm/quantize.cpp @@ -25,7 +25,7 @@ struct stablelm_hparams { }; // quantize a model -bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_mtype mtype) { +bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { gpt_vocab vocab; printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); @@ -79,7 +79,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fout.write((char *) &hparams.n_head, sizeof(hparams.n_head)); fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer)); fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot)); - fout.write((char *) &mtype, sizeof(hparams.ftype)); + fout.write((char *) &ftype, sizeof(hparams.ftype)); } // load vocab @@ -106,7 +106,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string & ".*weight", }; - if (!ggml_common_quantize_0(finp, fout, mtype, to_quant, {})) { + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) { fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); return false; } @@ -123,10 +123,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string & int main(int argc, char ** argv) { if (argc != 4) { fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); - fprintf(stderr, " type = 2 - q4_0\n"); - fprintf(stderr, " type = 3 - q4_1\n"); - fprintf(stderr, " type = 5 - q4_2\n"); - fprintf(stderr, " type = 6 - q4_3\n"); + ggml_print_ftypes(stderr); return 1; } @@ -140,7 +137,7 @@ int main(int argc, char ** argv) { const std::string fname_inp = argv[1]; const std::string fname_out = argv[2]; - const int mtype = atoi(argv[3]); + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); const int64_t t_main_start_us = ggml_time_us(); @@ -150,7 +147,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!stablelm_model_quantize(fname_inp, fname_out, ggml_mtype(mtype))) { + if (!stablelm_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/examples/whisper/quantize.cpp b/examples/whisper/quantize.cpp index e76b8ac4..994fabd4 100644 --- a/examples/whisper/quantize.cpp +++ b/examples/whisper/quantize.cpp @@ -36,7 +36,7 @@ struct whisper_filters { }; // quantize a model -bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_mtype mtype) { +bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { gpt_vocab vocab; printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); @@ -103,7 +103,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head)); fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer)); fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels)); - fout.write((char *) &mtype, sizeof(hparams.f16)); + fout.write((char *) &ftype, sizeof(hparams.f16)); } // load mel filters @@ -156,7 +156,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f "decoder.positional_embedding", }; - if (!ggml_common_quantize_0(finp, fout, mtype, { ".*" }, to_skip)) { + if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) { fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); return false; } @@ -187,7 +187,7 @@ int main(int argc, char ** argv) { const std::string fname_inp = argv[1]; const std::string fname_out = argv[2]; - const int mtype = atoi(argv[3]); + const int ftype = atoi(argv[3]); const int64_t t_main_start_us = ggml_time_us(); @@ -197,7 +197,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!whisper_model_quantize(fname_inp, fname_out, ggml_mtype(mtype))) { + if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 27589078..d9d3d214 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -222,7 +222,10 @@ extern "C" { GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, - GGML_TYPE_Q8_0 = 6, + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -832,6 +835,9 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); @@ -876,6 +882,7 @@ extern "C" { quantize_row_q_t quantize_row_q_reference; quantize_row_q_t quantize_row_q_dot; vec_dot_q_t vec_dot_q; + enum ggml_type vec_dot_type; } quantize_fns_t; quantize_fns_t ggml_internal_get_quantize_fn(size_t i); diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index fa511c1d..b1bd29b1 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -37,6 +37,30 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + __half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + __half d; // delta + __half m; // min + uint32_t qh; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + float d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); + static __global__ void dequantize_block_q4_0(const void * vx, float * y) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -131,6 +155,80 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } +static __global__ void dequantize_block_q5_0(const void * vx, float * y) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const uint8_t * pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } +} + +static __global__ void dequantize_block_q5_1(const void * vx, float * y) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * pp = x[i].qs; + + const uint32_t qh = x[i].qh; + + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; + } +} + +static __global__ void dequantize_block_q8_0(const void * vx, float * y) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const int8_t * pp = x[i].qs; + + for (int l = 0; l < QK8_0; l++) { + const int8_t vi = pp[l]; + + y[i*QK8_0 + l] = vi*d; + } +} + void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; dequantize_block_q4_0<<>>(vx, y); @@ -151,6 +249,21 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<>>(vx, y); } +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_0; + dequantize_block_q5_0<<>>(vx, y); +} + +void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_1; + dequantize_block_q5_1<<>>(vx, y); +} + +void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK8_0; + dequantize_block_q8_0<<>>(vx, y); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 16 diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h index 370bbc75..ed9b4418 100644 --- a/src/ggml-cuda.h +++ b/src/ggml-cuda.h @@ -35,6 +35,9 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); #ifdef __cplusplus } diff --git a/src/ggml.c b/src/ggml.c index 6e46c0e5..b3504d17 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -328,6 +328,20 @@ static ggml_fp16_t table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) static float table_f32_f16[1 << 16]; +#if defined(__ARM_NEON) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes (shl 4) +static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) }; +#endif + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. @@ -477,6 +491,19 @@ static inline int hsum_i32_4(const __m128i a) { } #if __AVX2__ || __AVX512F__ +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) @@ -673,15 +700,38 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + #define QK8_0 32 typedef struct { float d; // delta - float s0; // d * sum(qs[i]) low - float s1; // d * sum(qs[i]) high int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +#define QK8_1 32 +typedef struct { + float d; // delta + float s0; // d * sum(qs[i]) low + float s1; // d * sum(qs[i]) high + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { @@ -692,13 +742,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max + float max = 0.0f; for (int l = 0; l < QK4_0; l++) { const float v = x[i*QK4_0 + l]; - amax = MAX(amax, fabsf(v)); + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - const float d = amax / ((1 << 3) - 1); + const float d = max / -8; const float id = d ? 1.0f/d : 0.0f; y[i].d = d; @@ -707,8 +761,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r const float v0 = x[i*QK4_0 + l + 0]*id; const float v1 = x[i*QK4_0 + l + 1]*id; - const uint8_t vi0 = (int8_t)roundf(v0) + 8; - const uint8_t vi1 = (int8_t)roundf(v1) + 8; + const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8); + const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8); assert(vi0 < 16); assert(vi1 < 16); @@ -728,28 +782,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int #if defined(__POWER9_VECTOR__) const vector float v85 = vec_splats(8.5f); + const vector signed int v15 = vec_splats(15); for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max + float max = 0.0f; + float min = 0.0f; vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; + vector float maxv[8]; + vector float minv[8]; for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l); - for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]); - - for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]); - //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]); - amaxv[0] = vec_max(amaxv[0], amaxv[2]); - amaxv[4] = vec_max(amaxv[4], amaxv[6]); - //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]); - amaxv[0] = vec_max(amaxv[0], amaxv[4]); - - amax = MAX( - MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 3) - 1); + //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]); + + for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]); + //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]); + maxv[0] = vec_max(maxv[0], maxv[2]); + maxv[4] = vec_max(maxv[4], maxv[6]); + //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]); + maxv[0] = vec_max(maxv[0], maxv[4]); + + for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]); + //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]); + minv[0] = vec_min(minv[0], minv[2]); + minv[4] = vec_min(minv[4], minv[6]); + //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]); + minv[0] = vec_min(minv[0], minv[4]); + + + max = MAX( + MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)), + MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3))); + min = MIN( + MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)), + MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3))); + + const float magnitude = max >= fabsf(min) ? max : min; + const float d = magnitude / -8; const float id = d ? 1.0/d : 0.0; y[i].d = d; @@ -759,27 +827,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int for (int l = 0; l < 8; l++) { const vector float vf = vec_madd(srcv[l], vid, v85); const vector signed int vi = vec_signed(vf); + const vector signed int vc = vec_min(vi, v15); - pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4); - pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4); + pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4); + pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4); } } #elif __ARM_NEON for (int i = 0; i < nb; i++) { float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; + float32x4_t maxv[8]; + float32x4_t minv[8]; for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); - for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); - for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); - for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); - for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]); + for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]); + for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]); - const float amax = vmaxvq_f32(amaxv[0]); + for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]); + for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]); + for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]); + + const float max = vmaxvq_f32(maxv[0]); + const float min = vminvq_f32(minv[0]); - const float d = amax / ((1 << 3) - 1); + const float magnitude = max >= fabsf(min) ? max : min; + const float d = magnitude / -8; const float id = d ? 1.0f/d : 0.0f; y[i].d = d; @@ -788,9 +862,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int const float32x4_t v = vmulq_n_f32(srcv[l], id); const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); const int32x4_t vi = vcvtq_s32_f32(vf); + const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15)); - y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); - y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4); + y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4); } } #elif defined(__AVX2__) @@ -802,22 +877,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int __m256 v3 = _mm256_loadu_ps( x + 24 ); x += 32; - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + // Compute max for the block + __m256 max = _mm256_max_ps( v0, v1 ); + __m256 maxTmp = _mm256_max_ps( v2, v3 ); + max = _mm256_max_ps( max, maxTmp ); - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); const float maxScalar = _mm_cvtss_f32( max4 ); + // Compute min for the block + __m256 min = _mm256_min_ps( v0, v1 ); + __m256 minTmp = _mm256_min_ps( v2, v3 ); + min = _mm256_min_ps( min, minTmp ); + + __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) ); + min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) ); + min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) ); + const float minScalar = _mm_cvtss_f32( min4 ); + // Quantize these floats - const float d = maxScalar / 7.0f; + const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar; + const float d = magnitude / -8.0f; y[i].d = d; - const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; + const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f; const __m256 mul = _mm256_set1_ps( id ); // Apply the multiplier @@ -850,9 +934,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] + // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ] const __m256i off = _mm256_set1_epi8( 8 ); i0 = _mm256_add_epi8( i0, off ); + const __m256i maxNibble = _mm256_set1_epi8( 15 ); + i0 = _mm256_min_epi8( i0, maxNibble ); // Compress the vector into 4 bit/value, and store __m128i res = packNibbles( i0 ); @@ -867,22 +953,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int __m256 v3 = _mm256_loadu_ps( x + 24 ); x += 32; - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + // Compute max for the block + __m256 max = _mm256_max_ps( v0, v1 ); + __m256 maxTmp = _mm256_max_ps( v2, v3 ); + max = _mm256_max_ps( max, maxTmp ); - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); const float maxScalar = _mm_cvtss_f32( max4 ); + // Compute min for the block + __m256 min = _mm256_min_ps( v0, v1 ); + __m256 minTmp = _mm256_min_ps( v2, v3 ); + min = _mm256_min_ps( min, minTmp ); + + __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) ); + min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) ); + min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) ); + const float minScalar = _mm_cvtss_f32( min4 ); + // Quantize these floats - const float d = maxScalar / 7.0f; + const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar; + const float d = magnitude / -8.0f; y[i].d = d; - const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; + const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f; const __m256 mul = _mm256_set1_ps( id ); // Apply the multiplier @@ -923,10 +1018,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int ni0 = _mm_packs_epi16( ni0, ni2 ); ni4 = _mm_packs_epi16( ni4, ni6 ); - // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] - const __m128i off = _mm_set1_epi8( 8); + // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ] + const __m128i off = _mm_set1_epi8( 8 ); ni0 = _mm_add_epi8( ni0, off ); ni4 = _mm_add_epi8( ni4, off ); + const __m128i maxNibble = _mm_set1_epi8( 15 ); + ni0 = _mm_min_epi8( ni0, maxNibble ); + ni4 = _mm_min_epi8( ni4, maxNibble ); // Compress the vector into 4 bit/value, and store __m128i res = packNibbles( ni0, ni4 ); @@ -934,24 +1032,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int } #elif defined(__wasm_simd128__) for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max + float max = 0.0f; + float min = 0.0f; v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; + v128_t maxv[8]; + v128_t minv[8]; for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); - for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]); - for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); - for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); - for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); + for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]); + for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]); + for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]); + + for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]); + for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]); + for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]); - amax = MAX( - MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); + max = MAX( + MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)), + MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3))); + min = MIN( + MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)), + MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3))); - const float d = amax / ((1 << 3) - 1); + const float magnitude = max >= fabsf(min) ? max : min; + const float d = magnitude / -8; const float id = d ? 1.0/d : 0.0; y[i].d = d; @@ -960,9 +1066,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); + const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15)); - y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); - y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); + y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4); + y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4); } } #else @@ -1143,13 +1250,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max + float max = 0.0f; for (int l = 0; l < QK4_2; l++) { const float v = x[i*QK4_2 + l]; - amax = MAX(amax, fabsf(v)); + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - const float d = amax / ((1 << 3) - 1); + const float d = max / -8; const float id = d ? 1.0f/d : 0.0f; @@ -1159,93 +1270,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r const float v0 = x[i*QK4_2 + l + 0]*id; const float v1 = x[i*QK4_2 + l + 1]*id; - const uint8_t vi0 = (uint8_t)(v0 + 8.5f); - const uint8_t vi1 = (uint8_t)(v1 + 8.5f); - - assert(vi0 < 16); - assert(vi1 < 16); - - y[i].qs[l/2] = vi0 | (vi1 << 4); - } - } -} - -static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); - float val = fval + 12582912.f; - int i; memcpy(&i, &val, sizeof(int)); - return (i & 0x007fffff) - 0x00400000; -} - -static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates, - const float * restrict candidates, int8_t * restrict L) { - assert (nmin >= INT8_MIN); - assert (nmax <= INT8_MAX); - float amax = 0; - for (int i=0; i sumlxM2*suml2P) { - if (sumlxP2 > best*suml2P) { - best = sumlxP2/suml2P; bestScale = iscale; - } - } else { - if (sumlxM2 > best*suml2M) { - best = sumlxM2/suml2M; bestScale = -iscale; - } - } - } - float sumlx = 0; int suml2 = 0; - for (int i=0; i> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_0 == 0); + + block_q5_0 * restrict y = vy; + + quantize_row_q5_0_reference(x, y, k); +} + +static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK5_1; l++) { + const float v = x[i*QK5_1 + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int l = 0; l < QK5_1; l += 2) { + const float v0 = (x[i*QK5_1 + l + 0] - min)*id; + const float v1 = (x[i*QK5_1 + l + 1] - min)*id; + + const uint32_t vi0 = (int) (v0 + 0.5f); + const uint32_t vi1 = (int) (v1 + 0.5f); + + y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((vi0 & 0x10) >> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_1 == 0); + + block_q5_1 * restrict y = vy; + + quantize_row_q5_1_reference(x, y, k); +} + // reference implementation for deterministic creation of model files static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { assert(k % QK8_0 == 0); @@ -1320,18 +1447,52 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r y[i].d = d; + for (int l = 0; l < QK8_0; ++l) { + const float v0 = x[i*QK8_0 + l]*id; + + y[i].qs[l] = roundf(v0); + } + } +} + +static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK8_0 == 0); + + block_q8_0 * restrict y = vy; + + quantize_row_q8_0_reference(x, y, k); +} + +// reference implementation for deterministic creation of model files +static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK8_1; l++) { + const float v = x[i*QK8_1 + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + int sum0 = 0; int sum1 = 0; - for (int l = 0; l < QK8_0/2; ++l) { - const float v0 = x[i*QK8_0 + l]*id; - const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id; + for (int l = 0; l < QK8_1/2; ++l) { + const float v0 = x[i*QK8_1 + l]*id; + const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id; y[i].qs[ l] = roundf(v0); - y[i].qs[QK8_0/2 + l] = roundf(v1); + y[i].qs[QK8_1/2 + l] = roundf(v1); sum0 += y[i].qs[ l]; - sum1 += y[i].qs[QK8_0/2 + l]; + sum1 += y[i].qs[QK8_1/2 + l]; } y[i].s0 = d * sum0; @@ -1339,11 +1500,11 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r } } -static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; +static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; - block_q8_0 * restrict y = vy; + block_q8_1 * restrict y = vy; #if defined(__ARM_NEON) for (int i = 0; i < nb; i++) { @@ -1497,7 +1658,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int } #else // scalar - quantize_row_q8_0_reference(x, y, k); + quantize_row_q8_1_reference(x, y, k); #endif } @@ -1551,7 +1712,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const uint8x8_t v8 = vld1_u8(pp + l/2); // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F)); const uint8x8_t v1 = vshr_n_u8(v8, 4); // Convert to signed 8-bit integers @@ -1601,7 +1762,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_0; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = (vi0 - 8)*d; @@ -1667,7 +1828,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8x8_t v8 = vld1_u8(pp + l/2); // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F)); const uint8x8_t v1 = vshr_n_u8(v8, 4); // Interleave and combine @@ -1709,7 +1870,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_1; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = vi0*d + m; @@ -1739,7 +1900,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_2; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = (vi0 - 8)*d; @@ -1769,7 +1930,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_3; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = vi0*d + m; @@ -1784,10 +1945,103 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in } } +static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + const block_q5_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0x0F) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + + assert(!isnan(y[i*QK5_0 + l + 0])); + assert(!isnan(y[i*QK5_0 + l + 1])); + } + } +} + +static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + const block_q5_1 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; + + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const uint8_t vi0 = (vi & 0x0F) | vh0; + const uint8_t vi1 = (vi >> 4) | vh1; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; + + assert(!isnan(y[i*QK5_1 + l + 0])); + assert(!isnan(y[i*QK5_1 + l + 1])); + } + } +} + +static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + const block_q8_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = x[i].d; + + const int8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK8_0; ++l) { + y[i*QK8_0 + l] = pp[l]*d; + } + } +} + static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = { @@ -1796,34 +2050,63 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, }, [GGML_TYPE_Q4_1] = { .dequantize_row_q = dequantize_row_q4_1, .quantize_row_q = quantize_row_q4_1, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, - .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ggml_vec_dot_q4_1_q8_0, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, }, [GGML_TYPE_Q4_2] = { .dequantize_row_q = dequantize_row_q4_2, .quantize_row_q = quantize_row_q4_2, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference, .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q4_2_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, }, [GGML_TYPE_Q4_3] = { .dequantize_row_q = dequantize_row_q4_3, .quantize_row_q = quantize_row_q4_3, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = ggml_vec_dot_q4_3_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, + [GGML_TYPE_Q5_0] = { + .dequantize_row_q = dequantize_row_q5_0, + .quantize_row_q = quantize_row_q5_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ggml_vec_dot_q4_3_q8_0, + .vec_dot_q = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q5_1] = { + .dequantize_row_q = dequantize_row_q5_1, + .quantize_row_q = quantize_row_q5_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, }, [GGML_TYPE_Q8_0] = { - .dequantize_row_q = NULL, // TODO + .dequantize_row_q = dequantize_row_q8_0, .quantize_row_q = quantize_row_q8_0, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference, .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q8_1] = { + .dequantize_row_q = NULL, // TODO + .quantize_row_q = quantize_row_q8_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference, + .quantize_row_q_dot = quantize_row_q8_1, .vec_dot_q = NULL, // TODO + .vec_dot_type = GGML_TYPE_Q8_1, }, }; @@ -2439,17 +2722,14 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - float sum8 = 0; - for (int i = 0; i < nb; i += 2) { const block_q4_0 * restrict x0 = &x[i + 0]; const block_q4_0 * restrict x1 = &x[i + 1]; const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1); - - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2460,6 +2740,12 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); @@ -2474,21 +2760,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); @@ -2500,7 +2786,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2578,8 +2864,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * for (int j = 0; j < QK8_0/2; j++) { const uint8_t v0 = p0[j]; - const int i0 = (int8_t) (v0 & 0xf) - 8; - const int i1 = (int8_t) (v0 >> 4) - 8; + const int i0 = (int8_t) (v0 & 0x0F) - 8; + const int i1 = (int8_t) (v0 >> 4) - 8; const int i2 = p1[2*j + 0]; const int i3 = p1[2*j + 1]; @@ -2592,14 +2878,14 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } -static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_0; +static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_1; - assert(n % QK8_0 == 0); + assert(n % QK8_1 == 0); assert(nb % 2 == 0); const block_q4_1 * restrict x = vx; - const block_q8_0 * restrict y = vy; + const block_q8_1 * restrict y = vy; // TODO: add AVX / WASM SIMD / etc #if defined(__ARM_NEON) @@ -2611,12 +2897,12 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * for (int i = 0; i < nb; i += 2) { const block_q4_1 * restrict x0 = &x[i + 0]; const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1); - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2710,11 +2996,11 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int8_t * restrict p1 = y[i].qs; // TODO: this is very slow .. - for (int j = 0; j < QK8_0/2; j++) { + for (int j = 0; j < QK8_1/2; j++) { const uint8_t v0 = p0[j]; - const float f0 = d0*(v0 & 0xf) + m0; - const float f1 = d0*(v0 >> 4) + m0; + const float f0 = d0*(v0 & 0x0F) + m0; + const float f1 = d0*(v0 >> 4) + m0; const float f2 = d1*p1[2*j + 0]; const float f3 = d1*p1[2*j + 1]; @@ -2749,71 +3035,342 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); - const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); - const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs)); + const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs)); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // interleave + const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs); + const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs); + const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs); + const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)), + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d); + + sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)), + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( + vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)), + vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d); + + sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( + vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)), + vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d)); + + __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + __m256i bx = _mm256_set_m128i(bx1, bx0); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8(8); + bx = _mm256_sub_epi8(bx, off); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[2*i + 0].qs; + const uint8_t * restrict x1 = x[2*i + 1].qs; + const int8_t * restrict y0 = y[i].qs; + + const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); + const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); + + int sumi_0 = 0; + int sumi_1 = 0; + + for (int j = 0; j < QK8_0/4; j++) { + const uint8_t v0 = x0[j]; + const uint8_t v1 = x1[j]; + + const int i0_0 = (int8_t) (v0 & 0x0F) - 8; + const int i1_0 = (int8_t) (v0 >> 4) - 8; + + const int i0_1 = (int8_t) (v1 & 0x0F) - 8; + const int i1_1 = (int8_t) (v1 >> 4) - 8; + + const int i2_0 = y0[2*j + 0]; + const int i3_0 = y0[2*j + 1]; + + const int i2_1 = y0[2*(j + QK8_0/4) + 0]; + const int i3_1 = y0[2*(j + QK8_0/4) + 1]; + + sumi_0 += i0_0*i2_0 + i1_0*i3_0; + sumi_1 += i0_1*i2_1 + i1_1*i3_1; + } + + sumf += (d0 * y[i].d) * sumi_0; + sumf += (d1 * y[i].d) * sumi_1; + } + *s = sumf; +#endif +} + +static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_1; + + assert(n % QK8_1 == 0); + assert(nb % 2 == 0); + assert(QK8_1 == 2*QK4_3); + + const block_q4_3 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + for (int i = 0; i < nb; ++i) { + const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; + const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; + + const block_q8_1 * restrict y0 = &y[i + 0]; + + summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; + summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; + + const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F))); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + + // interleave + const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); + const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + + const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); + const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); +#endif + } + + *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 dx = _mm256_set_m128(d1, d0); + + summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0 + + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1; + + const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + const __m256i bx = _mm256_set_m128i(bx1, bx0); + + const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#else + // scalar + float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[2*i + 0].qs; + const uint8_t * restrict x1 = x[2*i + 1].qs; + const int8_t * restrict y0 = y[i].qs; + + const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); + const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); + const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); + const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); + + int sxy_0 = 0; + int sxy_1 = 0; + + for (int j = 0; j < QK8_1/4; j++) { + const uint8_t v0 = x0[j]; + const uint8_t v1 = x1[j]; + + const int x0_0 = v0 & 0x0F; + const int x1_0 = v0 >> 4; + + const int x0_1 = v1 & 0x0F; + const int x1_1 = v1 >> 4; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + const int y0_1 = y0[2*(j + QK8_1/4) + 0]; + const int y1_1 = y0[2*(j + QK8_1/4) + 1]; + + sxy_0 += x0_0*y0_0 + x1_0*y1_0; + sxy_1 += x0_1*y0_1 + x1_1*y1_1; + } + + sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; + } + *s = sumf; +#endif +} + +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + uint64_t tmp[4]; + + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s16b = vdupq_n_s8(0x10); + + // extract the 5th bit + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_u[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_u[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_u[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_u[(qh >> 24) ]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b)); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); // interleave - const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs); - const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs); - const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs); - const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs); + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add high bit and sub 16 + const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b); + const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b); // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)), - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d); + const float x0d = GGML_FP16_TO_FP32(x0->d); - sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)), - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d); +#if defined(__ARM_FEATURE_DOTPROD) + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)), - vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d); - - sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)), - vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d); + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); #endif } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + *s = vaddvq_f32(sumv); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2821,17 +3378,12 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { /* Compute combined scale for the block */ - const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); - const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); - const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d)); - - __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); - __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); - __m256i bx = _mm256_set_m128i(bx1, bx0); + const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8(8); - bx = _mm256_sub_epi8(bx, off); + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + bx = _mm256_or_si256(bx, bxhi); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -2846,104 +3398,110 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * // scalar float sumf = 0.0; for (int i = 0; i < nb; i++) { - const uint8_t * restrict x0 = x[2*i + 0].qs; - const uint8_t * restrict x1 = x[2*i + 1].qs; + const uint8_t * restrict x0 = x[i].qs; const int8_t * restrict y0 = y[i].qs; - const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); - const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - int sumi_0 = 0; - int sumi_1 = 0; + const float d = GGML_FP16_TO_FP32(x[i].d); - for (int j = 0; j < QK8_0/4; j++) { - const uint8_t v0 = x0[j]; - const uint8_t v1 = x1[j]; + int sxy = 0; - const int i0_0 = (int8_t) (v0 & 0xf) - 8; - const int i1_0 = (int8_t) (v0 >> 4) - 8; + for (int j = 0; j < QK8_0/2; j++) { + const uint8_t v0 = x0[j]; - const int i0_1 = (int8_t) (v1 & 0xf) - 8; - const int i1_1 = (int8_t) (v1 >> 4) - 8; + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; - const int i2_0 = y0[2*j + 0]; - const int i3_0 = y0[2*j + 1]; + const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16; + const int x1_0 = ((v0 >> 4) | x1_0h) - 16; - const int i2_1 = y0[2*(j + QK8_0/4) + 0]; - const int i3_1 = y0[2*(j + QK8_0/4) + 1]; + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; - sumi_0 += i0_0*i2_0 + i1_0*i3_0; - sumi_1 += i0_1*i2_1 + i1_1*i3_1; + sxy += x0_0*y0_0 + x1_0*y1_0; } - sumf += (d0 * y[i].d) * sumi_0; - sumf += (d1 * y[i].d) * sumi_1; + sumf += (d*sxy)*y[i].d; } *s = sumf; #endif } -static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_0; +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_1; - assert(n % QK8_0 == 0); + assert(n % QK8_1 == 0); assert(nb % 2 == 0); - assert(QK8_0 == 2*QK4_2); + assert(QK8_1 == QK5_1); - const block_q4_3 * restrict x = vx; - const block_q8_0 * restrict y = vy; + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; #if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); + float32x4_t sumv = vdupq_n_f32(0.0f); - float summs0 = 0.0f; - float summs1 = 0.0f; + float summs = 0.0f; + + uint64_t tmp[4]; for (int i = 0; i < nb; ++i) { - const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; - const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; - const block_q8_0 * restrict y0 = &y[i + 0]; + summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1); - summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; - summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; + // extract the 5th bit + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); - const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + tmp[0] = table_b2b_u[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_u[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_u[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_u[(qh >> 24) ]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F))); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); // interleave - const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); - const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add + const int8x16_t v0lf = vorrq_s8(v0lz, qhl); + const int8x16_t v0hf = vorrq_s8(v0hz, qhh); // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); - const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); - const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); + const float x0d = GGML_FP16_TO_FP32(x0->d); #if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); #endif } - *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; + *s = vaddvq_f32(sumv) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2951,16 +3509,14 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { - const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); - const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); - const __m256 dx = _mm256_set_m128(d1, d0); + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0 - + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1; + summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1); - const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); - const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); - const __m256i bx = _mm256_set_m128i(bx1, bx0); + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + bx = _mm256_or_si256(bx, bxhi); const __m256 dy = _mm256_broadcast_ss(&y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -2972,47 +3528,127 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * *s = hsum_float_8(acc) + summs; #else - // scalar float sumf = 0.0; + for (int i = 0; i < nb; i++) { - const uint8_t * restrict x0 = x[2*i + 0].qs; - const uint8_t * restrict x1 = x[2*i + 1].qs; + const uint8_t * restrict x0 = x[i].qs; const int8_t * restrict y0 = y[i].qs; - const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); - const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); - const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); - const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - int sxy_0 = 0; - int sxy_1 = 0; + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); - for (int j = 0; j < QK8_0/4; j++) { + int sxy = 0; + + for (int j = 0; j < QK8_1/2; j++) { const uint8_t v0 = x0[j]; - const uint8_t v1 = x1[j]; - const int x0_0 = v0 & 0xf; - const int x1_0 = v0 >> 4; + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; - const int x0_1 = v1 & 0xf; - const int x1_1 = v1 >> 4; + const int x0_0 = (v0 & 0x0F) | x0_0h; + const int x1_0 = (v0 >> 4) | x1_0h; const int y0_0 = y0[2*j + 0]; const int y1_0 = y0[2*j + 1]; - const int y0_1 = y0[2*(j + QK8_0/4) + 0]; - const int y1_1 = y0[2*(j + QK8_0/4) + 1]; - - sxy_0 += x0_0*y0_0 + x1_0*y1_0; - sxy_1 += x0_1*y0_1 + x1_1*y1_1; + sxy += x0_0*y0_0 + x1_0*y1_0; } - sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; + sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1); } + *s = sumf; #endif } +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == QK8_0); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d); + +#else + const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); + const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); + const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); + const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); + const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const int8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + int sumi = 0; + + for (int j = 0; j < QK8_0; j++) { + const int v0 = x0[j]; + const int v1 = y0[j]; + + sumi += v0*v1; + } + + sumf += (x[i].d*y[i].d)*sumi; + } + + *s = sumf; +#endif +} // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes @@ -3210,6 +3846,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #endif } +inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) { + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +} + inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE float max = -INFINITY; @@ -3262,12 +3906,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, + [GGML_TYPE_Q5_0] = QK5_0, + [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, + [GGML_TYPE_Q8_1] = QK8_1, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3276,12 +3923,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), + [GGML_TYPE_Q5_0] = sizeof(block_q5_0), + [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), + [GGML_TYPE_Q8_1] = sizeof(block_q8_1), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3291,12 +3941,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", + [GGML_TYPE_Q5_0] = "q5_0", + [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", + [GGML_TYPE_Q8_1] = "q8_1", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3305,12 +3958,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, + [GGML_TYPE_Q5_0] = true, + [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, + [GGML_TYPE_Q8_1] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -6522,6 +7178,9 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -6779,15 +7438,20 @@ static void ggml_compute_forward_sum_f32( const size_t nb02 = src0->nb[2]; const size_t nb03 = src0->nb[3]; + ggml_float sum = 0; + ggml_float row_sum = 0; + for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, - (float *) (dst->data), + ggml_vec_sum_ggf(ne00, + &row_sum, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; } } } + ((float *) dst->data)[0] = sum; } static void ggml_compute_forward_sum( @@ -7944,6 +8608,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const enum ggml_type type = src0->type; quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); @@ -8003,6 +8668,15 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } + else if (type == GGML_TYPE_Q5_0) { + dequantize_row_q_cuda = dequantize_row_q5_0_cuda; + } + else if (type == GGML_TYPE_Q5_1) { + dequantize_row_q_cuda = dequantize_row_q5_1_cuda; + } + else if (type == GGML_TYPE_Q8_0) { + dequantize_row_q_cuda = dequantize_row_q8_0_cuda; + } else { GGML_ASSERT(false); } @@ -8077,7 +8751,7 @@ static void ggml_compute_forward_mul_mat_q_f32( if (params->type == GGML_TASK_INIT) { char * wdata = params->wdata; - const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -8108,7 +8782,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const int ir1 = MIN(ir0 + dr, nr); void * wdata = params->wdata; - const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int ir = ir0; ir < ir1; ++ir) { // src0 indices @@ -8158,7 +8832,10 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -8387,7 +9064,10 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -8525,6 +9205,7 @@ static void ggml_compute_forward_soft_max_f32( uint16_t scvt; for (int i = 0; i < nc; i++) { + //printf("p[%3d] = %8.4f\n", i, p[i]); if (p[i] == -INFINITY) { p[i] = 0.0f; } else { @@ -10909,7 +11590,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else #endif { - cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type; + cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q]; } } else { GGML_ASSERT(false); @@ -12097,7 +12779,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_0; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12120,7 +12802,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_1; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12139,12 +12821,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * for (int j = 0; j < n; j += k) { block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2; - //quantize_row_q4_2_reference(src + j, y, k); - quantize_row_q4_2_rmse(src + j, y, k); + quantize_row_q4_2_reference(src + j, y, k); for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12167,7 +12848,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12179,6 +12860,87 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int j = 0; j < n; j += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0; + + quantize_row_q5_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + +size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int j = 0; j < n; j += k) { + block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1; + + quantize_row_q5_1_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_1*sizeof(block_q5_1)); +} + +size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int j = 0; j < n; j += k) { + block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0; + + quantize_row_q8_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK8_0; ++l) { + const int8_t vi = y[i].qs[l]; + + hist[vi/16 + 8]++; + } + } + } + + return (n/QK8_0*sizeof(block_q8_0)); +} + size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { size_t result = 0; switch (type) { @@ -12206,6 +12968,24 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(start % QK5_1 == 0); + block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; + result = ggml_quantize_q5_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(start % QK8_0 == 0); + block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; + result = ggml_quantize_q8_0(src + start, block, n, n, hist); + } break; default: assert(false); }