]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync latest ggml lib
authorGeorgi Gerganov <redacted>
Sun, 25 Jun 2023 11:22:21 +0000 (14:22 +0300)
committerGeorgi Gerganov <redacted>
Sun, 25 Jun 2023 11:30:44 +0000 (14:30 +0300)
examples/common-ggml.cpp
examples/common.cpp
examples/common.h
examples/main/main.cpp
examples/quantize/quantize.cpp
ggml-cuda.cu
ggml-cuda.h
ggml-opencl.h
ggml.c
ggml.h
whisper.cpp

index 9215dbeabd195339408f5ccb538f1a2dbcde5f3e..33ae03ae10ff419dd18b336e5e7d9406afc94d36 100644 (file)
@@ -52,6 +52,11 @@ bool ggml_common_quantize_0(
         case GGML_FTYPE_ALL_F32:
         case GGML_FTYPE_MOSTLY_F16:
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
+        case GGML_FTYPE_MOSTLY_Q2_K:
+        case GGML_FTYPE_MOSTLY_Q3_K:
+        case GGML_FTYPE_MOSTLY_Q4_K:
+        case GGML_FTYPE_MOSTLY_Q5_K:
+        case GGML_FTYPE_MOSTLY_Q6_K:
                 {
                     fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
                     return false;
@@ -187,6 +192,12 @@ bool ggml_common_quantize_0(
                 case GGML_TYPE_I16:
                 case GGML_TYPE_I32:
                 case GGML_TYPE_Q8_1:
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_Q8_K:
                 case GGML_TYPE_COUNT:
                     {
                         fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
index 76da30d9d02b270a16eb1034510dae9c064907bf..fe00278c2584e1c49d243420d3f69bab8c2262ca 100644 (file)
@@ -6,13 +6,21 @@
 #include "dr_wav.h"
 
 #include <cmath>
+#include <cstring>
 #include <fstream>
 #include <regex>
+#include <locale>
+#include <codecvt>
+#include <sstream>
 
 #ifndef M_PI
 #define M_PI 3.14159265358979323846
 #endif
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
@@ -52,7 +60,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.prompt.back() == '\n') {
                 params.prompt.pop_back();
             }
-        } else {
+        } else if (arg == "-tt" || arg == "--token_test") {
+            params.token_test = argv[++i];
+        }
+        else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, params);
             exit(0);
@@ -73,6 +84,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        prompt to start generation with (default: random)\n");
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        load prompt from a file\n");
+    fprintf(stderr, "  -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
+    fprintf(stderr, "                        test tokenization\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
@@ -117,6 +130,10 @@ std::string replace(const std::string & s, const std::string & from, const std::
     return result;
 }
 
+void gpt_vocab::add_special_token(const std::string & token) {
+    special_tokens.push_back(token);
+}
+
 std::map<std::string, int32_t> json_parse(const std::string & fname) {
     std::map<std::string, int32_t> result;
 
@@ -208,8 +225,28 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
     return result;
 }
 
-void gpt_vocab::add_special_token(const std::string & token) {
-    special_tokens.push_back(token);
+std::string convert_to_utf8(const std::wstring & input) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.to_bytes(input);
+}
+
+
+std::wstring convert_to_wstring(const std::string & input) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.from_bytes(input);
+}
+
+void gpt_split_words(std::string str, std::vector<std::string>& words) {
+    const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+    const std::regex re(pattern);
+    std::smatch m;
+
+    while (std::regex_search(str, m, re)) {
+        for (auto x : m) {
+            words.push_back(x);
+        }
+        str = m.suffix();
+    }
 }
 
 std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
@@ -218,63 +255,52 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     // first split the text into words
     {
         std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
 
         // Generate the subpattern from the special_tokens vector if it's not empty
         if (!vocab.special_tokens.empty()) {
+            const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
             std::string special_tokens_subpattern;
             for (const auto & token : vocab.special_tokens) {
                 if (!special_tokens_subpattern.empty()) {
                     special_tokens_subpattern += "|";
                 }
-                special_tokens_subpattern += token;
+                special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
             }
 
-            // Modify the regex pattern with the generated special tokens subpattern
-            pat = special_tokens_subpattern + "|" + pat;
-        }
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
+            std::regex re(special_tokens_subpattern);
+            std::smatch m;
+            // Split the text by special tokens.
+            while (std::regex_search(str, m, re)) {
+                // Split the substrings in-between special tokens into words.
+                gpt_split_words(m.prefix(), words);
+                // Add matched special tokens as words.
+                for (auto x : m) {
+                    words.push_back(x);
+                }
+                str = m.suffix();
             }
-            str = m.suffix();
+            // Remaining text without special tokens will be handled below.
         }
+
+        gpt_split_words(str, words);
     }
 
-    // find the longest tokens that form the words:
+    // find the longest token that forms each word in words:
     std::vector<gpt_vocab::id> tokens;
     for (const auto & word : words) {
-        if (word.size() == 0) continue;
-
-        int i = 0;
-        int n = word.size();
-        while (i < n) {
-            int j = n;
-            while (j > i) {
-                auto it = vocab.token_to_id.find(word.substr(i, j-i));
-                if (it != vocab.token_to_id.end()) {
+        for (int i = 0; i < (int) word.size(); ){
+            for (int j = word.size() - 1; j >= i; j--){
+                auto cand = word.substr(i, j-i+1);
+                auto it = vocab.token_to_id.find(cand);
+                if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
                     tokens.push_back(it->second);
-                    i = j;
-                    j = n;
+                    i = j + 1;
                     break;
                 }
-                --j;
-            }
-            if (i == n) {
-                break;
-            }
-            if (j == i) {
-                auto sub = word.substr(i, 1);
-                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
-                    tokens.push_back(vocab.token_to_id.at(sub));
-                } else {
-                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                else if (j == i){ // word.substr(i, 1) has no matching
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
+                    i++;
                 }
-                ++i;
             }
         }
     }
@@ -282,6 +308,70 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
+std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
+    std::vector<gpt_vocab::id> output;
+    std::stringstream ss(input);
+    std::string token;
+
+    while (std::getline(ss, token, delimiter)) {
+        output.push_back(std::stoi(token));
+    }
+
+    return output;
+}
+
+std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){
+    if (fpath_test.empty()){
+        fprintf(stderr, "%s : No test file found.\n", __func__);
+        return std::map<std::string, std::vector<gpt_vocab::id>>();
+    }
+
+    std::map<std::string, std::vector<gpt_vocab::id>> tests;
+
+    auto fin = std::ifstream(fpath_test, std::ios_base::in);
+    const char * delimeter = " => ";
+    const char del_tok = ',';
+    std::string line;
+    while (std::getline(fin, line)) {
+        size_t delimiterPos = line.find(delimeter);
+        if (delimiterPos != std::string::npos) {
+            std::string text = line.substr(0, delimiterPos);
+            std::string s_tokens = line.substr(delimiterPos + std::strlen(delimeter));
+            tests[text] = parse_tokens_from_string(s_tokens, del_tok);
+        }
+    }
+    return tests;
+}
+
+void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){
+    std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);
+
+    size_t n_fails = 0;
+
+    for (const auto & test : tests) {
+        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);
+
+        if (tokens != test.second){
+            n_fails++;
+
+            // print out failure cases
+            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
+            fprintf(stderr, "%s : tokens in hf:   ", __func__);
+            for (const auto & t : test.second) {
+                fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
+            }
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s : tokens in ggml: ", __func__);
+            for (const auto & t : tokens) {
+                fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
+            }
+            fprintf(stderr, "\n");
+        }
+    }
+
+    fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
+}
+
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
 
@@ -381,6 +471,122 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     return logits_id[idx].second;
 }
 
+gpt_vocab::id gpt_sample_top_k_top_p_repeat(
+        const gpt_vocab & vocab,
+        const float * logits,
+        const int32_t * last_n_tokens_data,
+        size_t last_n_tokens_data_size,
+        int    top_k,
+        double top_p,
+        double temp,
+        int repeat_last_n,
+        float repeat_penalty,
+        std::mt19937 & rng) {
+
+    int n_logits = vocab.id_to_token.size();
+
+    const auto * plogits = logits;
+
+    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
+
+    if (temp <= 0) {
+        // select the token with the highest logit directly
+        float max_logit = plogits[0];
+        gpt_vocab::id max_id = 0;
+
+        for (int i = 1; i < n_logits; ++i) {
+            if (plogits[i] > max_logit) {
+                max_logit = plogits[i];
+                max_id = i;
+            }
+        }
+        return max_id;
+    }
+
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const float scale = 1.0f/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0f) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
+        }
+    }
+
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+//    printf("\n");
+//    for (int i = 0; i < (int) probs.size(); i++) {
+//    for (int i = 0; i < 10; i++) {
+//        printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+//    }
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+
+}
+
 bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
     drwav wav;
     std::vector<uint8_t> wav_data; // used for pipe input from stdin
index 29d0792af2d356945fb19b70709473963de3fccb..7e9b867d3d6f1f8af3a00140a153aa321c68d72f 100644 (file)
@@ -26,8 +26,9 @@ struct gpt_params {
 
     int32_t n_batch = 8; // batch size for prompt processing
 
-    std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
-    std::string prompt;
+    std::string model      = "models/gpt-2-117M/ggml-model.bin"; // model path
+    std::string prompt     = "";
+    std::string token_test = "";
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -61,6 +62,12 @@ struct gpt_vocab {
 // poor-man's JSON parsing
 std::map<std::string, int32_t> json_parse(const std::string & fname);
 
+std::string convert_to_utf8(const std::wstring & input);
+
+std::wstring convert_to_wstring(const std::string & input);
+
+void gpt_split_words(std::string str, std::vector<std::string>& words);
+
 // split text into tokens
 //
 // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
@@ -73,6 +80,15 @@ std::map<std::string, int32_t> json_parse(const std::string & fname);
 //
 std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
 
+// test outputs of gpt_tokenize
+//
+//   - compare with tokens generated by the huggingface tokenizer
+//   - test cases are chosen based on the model's main language (under 'prompt' directory)
+//   - if all sentences are tokenized identically, print 'All tests passed.'
+//   - otherwise, print sentence, huggingface tokens, ggml tokens
+//
+void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test);
+
 // load the tokens from encoder.json
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 
@@ -92,6 +108,18 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double temp,
         std::mt19937 & rng);
 
+gpt_vocab::id gpt_sample_top_k_top_p_repeat(
+        const gpt_vocab & vocab,
+        const float * logits,
+        const int32_t * last_n_tokens_data,
+        size_t last_n_tokens_data_size,
+        int    top_k,
+        double top_p,
+        double temp,
+        int repeat_last_n,
+        float repeat_penalty,
+        std::mt19937 & rng);
+
 //
 // Audio utils
 //
index 9b1958620b6f13ec082a817d236914bf74436cfb..07a7591fe2151dc3c3b649c545c8f01b504aa288 100644 (file)
 #include <vector>
 #include <cstring>
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
 // Lowest is red, middle is yellow, highest is green.
 const std::vector<std::string> k_colors = {
@@ -148,7 +152,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-f"    || arg == "--file")           { params.fname_inp.emplace_back(argv[++i]); }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            return false;
+            whisper_print_usage(argc, argv, params);
+            exit(0);
         }
     }
 
@@ -423,13 +428,13 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
         indent++;
     };
 
-    auto end_arr = [&](bool end = false) {
+    auto end_arr = [&](bool end) {
         indent--;
         doindent();
         fout << (end ? "]\n" : "},\n");
     };
 
-    auto start_obj = [&](const char *name = nullptr) {
+    auto start_obj = [&](const char *name) {
         doindent();
         if (name) {
             fout << "\"" << name << "\": {\n";
@@ -439,7 +444,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
         indent++;
     };
 
-    auto end_obj = [&](bool end = false) {
+    auto end_obj = [&](bool end) {
         indent--;
         doindent();
         fout << (end ? "}\n" : "},\n");
@@ -450,24 +455,24 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
         fout << "\"" << name << "\": ";
     };
 
-    auto value_s = [&](const char *name, const char *val, bool end = false) {
+    auto value_s = [&](const char *name, const char *val, bool end) {
         start_value(name);
         char * val_escaped = escape_double_quotes_and_backslashes(val);
         fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
         free(val_escaped);
     };
 
-    auto end_value = [&](bool end = false) {
+    auto end_value = [&](bool end) {
         fout << (end ? "\n" : ",\n");
     };
 
-    auto value_i = [&](const char *name, const int64_t val, bool end = false) {
+    auto value_i = [&](const char *name, const int64_t val, bool end) {
         start_value(name);
         fout << val;
         end_value(end);
     };
 
-    auto value_b = [&](const char *name, const bool val, bool end = false) {
+    auto value_b = [&](const char *name, const bool val, bool end) {
         start_value(name);
         fout << (val ? "true" : "false");
         end_value(end);
@@ -479,35 +484,35 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
     }
 
     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-    start_obj();
-        value_s("systeminfo", whisper_print_system_info());
+    start_obj(nullptr);
+        value_s("systeminfo", whisper_print_system_info(), false);
         start_obj("model");
-            value_s("type", whisper_model_type_readable(ctx));
-            value_b("multilingual", whisper_is_multilingual(ctx));
-            value_i("vocab", whisper_model_n_vocab(ctx));
+            value_s("type", whisper_model_type_readable(ctx), false);
+            value_b("multilingual", whisper_is_multilingual(ctx), false);
+            value_i("vocab", whisper_model_n_vocab(ctx), false);
             start_obj("audio");
-                value_i("ctx", whisper_model_n_audio_ctx(ctx));
-                value_i("state", whisper_model_n_audio_state(ctx));
-                value_i("head", whisper_model_n_audio_head(ctx));
+                value_i("ctx", whisper_model_n_audio_ctx(ctx), false);
+                value_i("state", whisper_model_n_audio_state(ctx), false);
+                value_i("head", whisper_model_n_audio_head(ctx), false);
                 value_i("layer", whisper_model_n_audio_layer(ctx), true);
-            end_obj();
+            end_obj(false);
             start_obj("text");
-                value_i("ctx", whisper_model_n_text_ctx(ctx));
-                value_i("state", whisper_model_n_text_state(ctx));
-                value_i("head", whisper_model_n_text_head(ctx));
+                value_i("ctx", whisper_model_n_text_ctx(ctx), false);
+                value_i("state", whisper_model_n_text_state(ctx), false);
+                value_i("head", whisper_model_n_text_head(ctx), false);
                 value_i("layer", whisper_model_n_text_layer(ctx), true);
-            end_obj();
-            value_i("mels", whisper_model_n_mels(ctx));
+            end_obj(false);
+            value_i("mels", whisper_model_n_mels(ctx), false);
             value_i("ftype", whisper_model_ftype(ctx), true);
-        end_obj();
+        end_obj(false);
         start_obj("params");
-            value_s("model", params.model.c_str());
-            value_s("language", params.language.c_str());
+            value_s("model", params.model.c_str(), false);
+            value_s("language", params.language.c_str(), false);
             value_b("translate", params.translate, true);
-        end_obj();
+        end_obj(false);
         start_obj("result");
             value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
-        end_obj();
+        end_obj(false);
         start_arr("transcription");
 
             const int n_segments = whisper_full_n_segments(ctx);
@@ -516,15 +521,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
                 const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                 const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-                start_obj();
+                start_obj(nullptr);
                     start_obj("timestamps");
-                        value_s("from", to_timestamp(t0, true).c_str());
+                        value_s("from", to_timestamp(t0, true).c_str(), false);
                         value_s("to", to_timestamp(t1, true).c_str(), true);
-                    end_obj();
+                    end_obj(false);
                     start_obj("offsets");
-                        value_i("from", t0 * 10);
+                        value_i("from", t0 * 10, false);
                         value_i("to", t1 * 10, true);
-                    end_obj();
+                    end_obj(false);
                     value_s("text", text, true);
                 end_obj(i == (n_segments - 1));
             }
index f2fdb0f7f9c708009f417cf22306beba4c04d91a..3df7b1c71232dd68175d330d1dffbf5da6e0e7a6 100644 (file)
@@ -99,17 +99,17 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
         fprintf(stderr, "%s: ftype (dst)   = %d\n", __func__, ftype_dst);
         fprintf(stderr, "%s: qntvr (dst)   = %d\n", __func__, GGML_QNT_VERSION);
 
-        fout.write((char *) &hparams.n_vocab,       sizeof(hparams.n_vocab));
-        fout.write((char *) &hparams.n_audio_ctx,   sizeof(hparams.n_audio_ctx));
-        fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
-        fout.write((char *) &hparams.n_audio_head,  sizeof(hparams.n_audio_head));
-        fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
-        fout.write((char *) &hparams.n_text_ctx,    sizeof(hparams.n_text_ctx));
-        fout.write((char *) &hparams.n_text_state,  sizeof(hparams.n_text_state));
-        fout.write((char *) &hparams.n_text_head,   sizeof(hparams.n_text_head));
-        fout.write((char *) &hparams.n_text_layer,  sizeof(hparams.n_text_layer));
-        fout.write((char *) &hparams.n_mels,        sizeof(hparams.n_mels));
-        fout.write((char *) &ftype_dst,             sizeof(hparams.ftype));
+        fout.write((const char *) &hparams.n_vocab,       sizeof(hparams.n_vocab));
+        fout.write((const char *) &hparams.n_audio_ctx,   sizeof(hparams.n_audio_ctx));
+        fout.write((const char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
+        fout.write((const char *) &hparams.n_audio_head,  sizeof(hparams.n_audio_head));
+        fout.write((const char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
+        fout.write((const char *) &hparams.n_text_ctx,    sizeof(hparams.n_text_ctx));
+        fout.write((const char *) &hparams.n_text_state,  sizeof(hparams.n_text_state));
+        fout.write((const char *) &hparams.n_text_head,   sizeof(hparams.n_text_head));
+        fout.write((const char *) &hparams.n_text_layer,  sizeof(hparams.n_text_layer));
+        fout.write((const char *) &hparams.n_mels,        sizeof(hparams.n_mels));
+        fout.write((const char *) &ftype_dst,             sizeof(hparams.ftype));
     }
 
     // load mel filters
@@ -138,15 +138,17 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
         //    return false;
         //}
 
-        std::string word;
+        char word[128];
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             finp.read ((char *) &len, sizeof(len));
             fout.write((char *) &len, sizeof(len));
 
-            word.resize(len);
-            finp.read ((char *) word.data(), len);
-            fout.write((char *) word.data(), len);
+            word[len] = '\0';
+
+            finp.read ((char *) word, len);
+            fout.write((char *) word, len);
 
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
index 35d2e457cbf8487fd967b6459925a85d1ab0282b..36a251ecce973bddee7df398cd61931fbed22cf9 100644 (file)
@@ -1,8 +1,10 @@
 #include <cstddef>
 #include <cstdint>
+#include <limits>
 #include <stdint.h>
 #include <stdio.h>
 #include <atomic>
+#include <assert.h>
 
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include "ggml-cuda.h"
 #include "ggml.h"
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
 #define CUDA_CHECK(err)                                                                 \
@@ -23,18 +29,44 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         }                                                                               \
     } while (0)
 
+#if CUDART_VERSION >= 12000
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
+                    err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+#else
 #define CUBLAS_CHECK(err)                                                               \
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
+#endif // CUDART_VERSION >= 11
+
+#ifdef GGML_CUDA_DMMV_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif //GGML_CUDA_DMMV_F16
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+typedef void (*ggml_cuda_op_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
+    float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -83,9 +115,84 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+//================================= k-quants
+
+#define QK_K 256
+
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;                  // super-block scale for quantized scales
+    half dmin;               // super-block scale for quantized mins
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+
+typedef struct {
+    uint8_t hmask[QK_K/8];
+    uint8_t qs[QK_K/4]; // nibbles / quants
+    uint8_t scales[3*QK_K/64];
+    half d;
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+
+typedef struct {
+    half    d;                   // super-block scale for quantized scales
+    half    dmin;                // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64];   // scales, quantized with 6 bits
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+
+typedef struct {
+    uint8_t ql[QK_K/2];   // quants, lower 4 bits
+    uint8_t qh[QK_K/4];   // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales
+    half    d;         // delta
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
+
+#define WARP_SIZE 32
+
+#define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_SCALE_BLOCK_SIZE 256
+#define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+#ifndef GGML_CUDA_DMMV_Y
+#define GGML_CUDA_DMMV_Y 1
+#endif
+
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 2
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] + y[i];
+}
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -96,713 +203,2101 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
-static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const float eps = 1e-6;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int i = 0; i < ncols; i += WARP_SIZE) {
+        const int col = i + tid;
+        const float xi = x[row*ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = 1.0f / sqrtf(mean + eps);
+
+    for (int i = 0; i < ncols; i += WARP_SIZE) {
+        const int col = i + tid;
+        dst[row*ncols + col] = scale * x[row*ncols + col];
+    }
+}
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
-    const uint8_t vui = x[ib].qs[iqs];
+    const int vui = x[ib].qs[iqs];
 
-    const int8_t vi0 = vui & 0xF;
-    const int8_t vi1 = vui >> 4;
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
 
-    v0 = (vi0 - 8)*d;
-    v1 = (vi1 - 8)*d;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hsub2(v, {8.0f, 8.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 8.0f) * d;
+    v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const float d = x[ib].d;
-    const float m = x[ib].m;
+    const dfloat d = x[ib].d;
+    const dfloat m = x[ib].m;
 
-    const uint8_t vui = x[ib].qs[iqs];
+    const int vui = x[ib].qs[iqs];
 
-    const int8_t vi0 = vui & 0xF;
-    const int8_t vi1 = vui >> 4;
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
 
-    v0 = vi0*d + m;
-    v1 = vi1*d + m;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
 
-    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
-    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
 
-    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
-    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1) - 16;
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-    v0 = x0*d;
-    v1 = x1*d;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hsub2(v, {16.0f, 16.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 16.0f) * d;
+    v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const float d = x[ib].d;
-    const float m = x[ib].m;
+    const dfloat d = x[ib].d;
+    const dfloat m = x[ib].m;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
 
-    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
-    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
 
-    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
-    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1);
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-    v0 = x0*d + m;
-    v1 = x1*d + m;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
-    const int8_t vi0 = x[ib].qs[iqs + 0];
-    const int8_t vi1 = x[ib].qs[iqs + 1];
+    v.x = x[ib].qs[iqs + 0];
+    v.y = x[ib].qs[iqs + 1];
 
-    v0 = vi0*d;
-    v1 = vi1*d;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hmul2(v, {d, d});
+#else
+    v.x *= d;
+    v.y *= d;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
-    const half * x = (const half *) vx;
+//================================== k-quants
 
-    v0 = __half2float(x[ib + 0]);
-    v1 = __half2float(x[ib + 1]);
-}
+static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
-    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
+    const int i   = blockIdx.x;
+    const int tid = threadIdx.x;
+    const int n   = tid/32;
+    const int l   = tid - 32*n;
+    const int is  = 8*n + l/16;
 
-    if (i >= k) {
-        return;
-    }
+    const block_q2_K * x = (const block_q2_K *) vx;
 
-    const int ib = i/qk; // block index
-    const int iqs = (i%qk)/qr; // quant index
-    const int iybs = i - i%qk; // y block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    const uint8_t q = x[i].qs[32*n + l];
+    float * y = yy + i*QK_K + 128*n;
+
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
 
-    // dequantize
-    float & v0 = y[iybs + iqs + 0];
-    float & v1 = y[iybs + iqs + y_offset];
-    dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int tid = threadIdx.x;
+static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    int r = threadIdx.x/4;
+    int i = blockIdx.x;
+    int tid = r/2;
+    int is0 = r%2;
+    int l0 = 16*is0 + 4*(threadIdx.x%4);
+    int n = tid / 4;
+    int j = tid - 4*n;
 
-    __shared__ float tmp[block_size]; // separate sum for each thread
-    tmp[tid] = 0;
+    const block_q3_K * x = (const block_q3_K *) vx;
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
-        const int iybs = col - col%qk; // y block start index
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j + is0;
+    int shift = 2*j;
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = x[i].d;
+    float dl = d_all * (us - 32);
 
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
-    }
+    float * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
 
-    // sum up partial sums and write back result
-    __syncthreads();
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        __syncthreads();
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
+    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 
-static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
-    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
-static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
 }
 
-static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+    const block_q4_K * x = (const block_q4_K *) vx;
 
-static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    const int i = blockIdx.x;
 
-static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    //// assume 64 threads - this is very slightly better than the one below
+    //const int tid = threadIdx.x;
+    //const int il  = tid/16;
+    //const int ir  = tid%16;
+    //const int is  = 2*il;
+    //const int n   = 2;
 
-static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int is  = 2*il;
+    const int n   = 4;
+
+    float * y = yy + i*QK_K + 64*il + n*ir;
+
+    const float dall = x[i].d;
+    const float dmin = x[i].dmin;
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l +32] = d2 * (q[l] >>  4) - m2;
+    }
 }
 
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
-}
+static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+    const block_q5_K * x = (const block_q5_K *) vx;
 
-static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
-}
+    const int i = blockIdx.x;
 
-static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = threadIdx.x;
+    const int il  = tid/16;   // il is in 0...3
+    const int ir  = tid%16;   // ir is in 0...15
+    const int is  = 2*il;     // is is in 0...6
+
+    float * y = yy + i*QK_K + 64*il + 2*ir;
+
+    const float dall = x[i].d;
+    const float dmin = x[i].dmin;
+
+    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+    const uint8_t * qh = x[i].qh + 2*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 }
 
-static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
-}
+static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+    const block_q6_K * x = (const block_q6_K *) vx;
 
-static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
-}
+    const int i = blockIdx.x;
 
-static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = threadIdx.x;
+    const int ip  = tid/32;   // ip is 0 or 1
+    const int il  = tid - 32*ip; // 0...32
+    const int is  = 8*ip + il/16;
 
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
-}
+    float * y = yy + i*QK_K + 128*ip + il;
 
-static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_row_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_row_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_row_q8_0_cuda;
-        case GGML_TYPE_F16:
-            return convert_fp16_to_fp32_cuda;
-        default:
-            return nullptr;
-    }
-}
+    const float d = x[i].d;
 
-static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_mul_mat_vec_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_mul_mat_vec_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_mul_mat_vec_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_mul_mat_vec_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_mul_mat_vec_q8_0_cuda;
-        case GGML_TYPE_F16:
-            return convert_mul_mat_vec_f16_cuda;
-        default:
-            return nullptr;
-    }
-}
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t   qh = x[i].qh[32*ip + il];
+    const int8_t  * sc = x[i].scales + is;
 
-// buffer pool for cuda
-#define MAX_CUDA_BUFFERS 256
+    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
 
-struct scoped_spin_lock {
-    std::atomic_flag& lock;
-    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
-        while (lock.test_and_set(std::memory_order_acquire)) {
-            ; // spin
-        }
-    }
-    ~scoped_spin_lock() {
-        lock.clear(std::memory_order_release);
-    }
-    scoped_spin_lock(const scoped_spin_lock&) = delete;
-    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
-};
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-struct cuda_buffer {
-    void * ptr = nullptr;
-    size_t size = 0;
-};
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
-static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
-    scoped_spin_lock lock(g_cuda_pool_lock);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[i];
-        if (b.size >= size && b.ptr != nullptr) {
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
-        }
-    }
-    void * ptr;
-    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
-    *actual_size = size;
-    return ptr;
-}
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
-    scoped_spin_lock lock(g_cuda_pool_lock);
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[i];
-        if (b.ptr == nullptr) {
-            b.ptr = ptr;
-            b.size = size;
-            return;
-        }
-    }
-    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
-    CUDA_CHECK(cudaFree(ptr));
-}
+    const int step = 16/K_QUANTS_PER_ITERATION;
 
-#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
-#define GGML_CUDA_MAX_EVENTS 64
-static cublasHandle_t g_cublasH = nullptr;
-static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-void ggml_init_cublas() {
-    if (g_cublasH == nullptr) {
-        // create streams
-        for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
-        }
-        // create events
-        for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
-        }
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
 
-        // create cublas handle
-        CUBLAS_CHECK(cublasCreate(&g_cublasH));
-        CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
+    float tmp = 0; // partial sum for thread in warp
 
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
-    }
-}
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
 
-void * ggml_cuda_host_malloc(size_t size) {
-    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
-        return nullptr;
-    }
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
-    void * ptr = nullptr;
-    cudaError_t err = cudaMallocHost((void **) &ptr, size);
-    if (err != cudaSuccess) {
-        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
-            size/1024.0/1024.0, cudaGetErrorString(err));
-        return nullptr;
-    }
+        const float   * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
 
-    return ptr;
-}
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
 
-void ggml_cuda_host_free(void * ptr) {
-    CUDA_CHECK(cudaFreeHost(ptr));
-}
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
-static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
-    const uint64_t ne0 = src->ne[0];
-    const uint64_t ne1 = src->ne[1];
-    const uint64_t nb0 = src->nb[0];
-    const uint64_t nb1 = src->nb[1];
-    const uint64_t nb2 = src->nb[2];
-    const uint64_t nb3 = src->nb[3];
-    const enum ggml_type type = src->type;
-    const size_t ts = ggml_type_size(type);
-    const size_t bs = ggml_blck_size(type);
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
 
-    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
-    } else if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
-    } else {
-        for (uint64_t i1 = 0; i1 < ne1; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
-            if (r != cudaSuccess) return r;
         }
-        return cudaSuccess;
+        tmp += dall * sum1 - dmin * sum2;
+
     }
-}
 
-static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[2];
-    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-    size_t x_size, d_size;
-
-    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
-    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
-    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int i0 = i03*ne02 + i02;
-            float * c_X2 = d_X + i0*ne01*ne00;
-            float * c_D2 = d_D + i0*ne01*ne00;
-
-            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
-
-            // copy src0 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
-
-            // wait for data
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
-
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                const int64_t i13 = i03%ne13;
-                const int64_t i12 = i02%ne12;
-                const int64_t i11 = i01%ne11;
-                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
-                float * c_X1 = c_X2 + i01*ne00;
-                float * c_Y = d_Y + i1*ne10;
-                float * c_D1 = c_D2 + i01*ne00;
-
-                // compute
-                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
-                CUDA_CHECK(cudaGetLastError());
-            }
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
-        }
+    if (tid == 0) {
+        dst[row] = tmp;
     }
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_D, d_size);
 }
 
-static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
 
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const uint8_t * q1 = x[i].qs + q_offset;
+        const uint8_t * q2 = q1 + 64;
+        const float   * y1 = yy + i*QK_K + y_offset;
+        const float   * y2 = y1 + 128;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * ql2 = ql1 + 64;
+        const uint8_t * qh  = x[i].qh + l0;
+        const float   * y1  = yy + i*QK_K + y_offset;
+        const float   * y2  = y1 + 128;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const half * x = (const half *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0]        = v.x;
+    y[iybs + iqs + y_offset] = v.y;
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = threadIdx.x;
+
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_CUDA_DMMV_F16
+    half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_CUDA_DMMV_F16
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_CUDA_DMMV_F16
+            tmp += __hmul2(v, {
+                y[iybs + iqs + j/qr + 0],
+                y[iybs + iqs + j/qr + y_offset]
+            });
+#else
+            tmp += v.x * y[iybs + iqs + j/qr + 0];
+            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+#endif // GGML_CUDA_DMMV_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+#ifdef GGML_CUDA_DMMV_F16
+        dst[row] = tmp.x + tmp.y;
+#else
+        dst[row] = tmp;
+#endif // GGML_CUDA_DMMV_F16
+    }
+}
+
+static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+    const half * x = (half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        // x is transposed and permuted
+        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+
+        // y is not transposed but permuted
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // dst is not transposed and not permuted
+    const int idst = channel*nrows_dst + row_dst;
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
+
+    const half * x = (half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    const int idst = channel*nrows_dst + row_dst;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (float *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (float *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = __float2half(*xi);
+}
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                                   const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                   const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = i - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = i - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+    cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+// rope == RoPE == rotary positional embedding
+static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
+    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const float theta = p*powf(theta_scale, col/2);
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int i = row*ncols + col;
+    // dst[i] = col > n_past + row ? -INFINITY : x[i];
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+}
+
+// the CUDA soft max implementation differs from the CPU implementation
+// instead of doubles floats are used
+// values are also not normalized to the maximum value by subtracting it in the exponential function
+// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
+static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int block_size = blockDim.x;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0;
+
+    for (int block_start = 0; block_start < ncols; block_start += block_size) {
+        const int col = block_start + tid;
+
+        if (col >= ncols) {
+            break;
+        }
+
+        const int i = row*ncols + col;
+        const float val = expf(x[i]);
+        tmp += val;
+        dst[i] = val;
+    }
+
+    // sum up partial sums
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    for (int block_start = 0; block_start < ncols; block_start += block_size) {
+        const int col = block_start + tid;
+
+        if (col >= ncols) {
+            break;
+        }
+
+        const int i = row*ncols + col;
+        dst[i] /= tmp;
+    }
+}
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = scale * x[i];
+}
+
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
+static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F16:
+            return convert_fp16_to_fp32_cuda;
+        default:
+            return nullptr;
+    }
+}
+
+static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
+    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
+}
+
+static void ggml_cpy_f32_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
+    GGML_ASSERT(nrows % 2 == 0);
+    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, nrows_x, 1);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(1, nrows_x, 1);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 256
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+    }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+static void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
+        }
+    }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+
+static void * g_scratch_buffer = nullptr;
+static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_offset = 0;
+
+static int g_device_count = -1;
+static int g_main_device = 0;
+static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+
+static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
+
+void ggml_init_cublas() {
+    static bool initialized = false;
+
+    if (!initialized) {
+        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
+        int64_t total_vram = 0;
+        fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
+        for (int id = 0; id < g_device_count; ++id) {
+            cudaDeviceProp prop;
+            CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+            fprintf(stderr, "  Device %d: %s\n", id, prop.name);
+            g_tensor_split[id] = total_vram;
+            total_vram += prop.totalGlobalMem;
+        }
+        for (int id = 0; id < g_device_count; ++id) {
+            g_tensor_split[id] /= total_vram;
+        }
+
+        for (int id = 0; id < g_device_count; ++id) {
+            CUDA_CHECK(cudaSetDevice(id));
+
+            // create main stream
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
+
+            // create cublas handle
+            CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
+            CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+        }
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+        initialized = true;
+    }
+}
+
+void ggml_cuda_set_tensor_split(const float * tensor_split) {
+    bool all_zero = true;
+    for (int i = 0; i < g_device_count; ++i) {
+        if (tensor_split[i] != 0.0f) {
+            all_zero = false;
+            break;
+        }
+    }
+    if (all_zero) {
+        return;
+    }
+    float split_sum = 0.0f;
+    for (int i = 0; i < g_device_count; ++i) {
+        g_tensor_split[i] = split_sum;
+        split_sum += tensor_split[i];
+    }
+    for (int i = 0; i < g_device_count; ++i) {
+        g_tensor_split[i] /= split_sum;
+    }
+}
+
+void * ggml_cuda_host_malloc(size_t size) {
+    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+        return nullptr;
+    }
+
+    void * ptr = nullptr;
+    cudaError_t err = cudaMallocHost((void **) &ptr, size);
+    if (err != cudaSuccess) {
+        // The allocation error can be bypassed. A null ptr will assigned out of this function.
+        // This can fixed the OOM error in WSL.
+        cudaGetLastError();
+        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+            size/1024.0/1024.0, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
+
+static cudaError_t ggml_cuda_cpy_tensor_2d(
+    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+    cudaMemcpyKind kind;
+    char * src_ptr;
+    if (src->backend == GGML_BACKEND_CPU) {
+        kind = cudaMemcpyHostToDevice;
+        src_ptr = (char *) src->data;
+    } else if (src->backend == GGML_BACKEND_GPU) {
+        kind = cudaMemcpyDeviceToDevice;
+        struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+        int id;
+        CUDA_CHECK(cudaGetDevice(&id));
+        src_ptr = (char *) extra->data_device[id];
+    } else {
+        GGML_ASSERT(false);
+    }
+    char * dst_ptr = (char *) dst;
+
+    const int64_t ne0 = src->ne[0];
+    const int64_t nb0 = src->nb[0];
+    const int64_t nb1 = src->nb[1];
+    const int64_t nb2 = src->nb[2];
+    const int64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;
+
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
+
+inline void ggml_cuda_op_add(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_mul(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
+        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
+        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
+        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
+        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
+
+        // compute
+        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
+        CUDA_CHECK(cudaGetLastError());
+    }
 
-    size_t x_size, y_size, d_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
-    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i02;
+}
 
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
+inline void ggml_cuda_op_silu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-            float * c_X = d_X + i * x_ne;
-            float * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
 
-            // copy data to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta,  c_D, ne01));
+inline void ggml_cuda_op_rms_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_dequantize_mul_mat_vec(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddq_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = i01_high - i01_low;
+
+// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_DMMV_F16
+    size_t ash;
+    dfloat * src1_dfloat = nullptr; // dfloat == half
+
+    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+    if (src1_convert_f16) {
+        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+                                ne00, 1, sizeof(float), 0, 0,
+                                ne00, 1, sizeof(half),  0, 0, cudaStream_main);
+    }
+#else
+    dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_CUDA_DMMV_F16
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q2_K:
+            dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q3_K:
+            dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_K:
+            dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_K:
+            dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q6_K:
+            dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    CUDA_CHECK(cudaGetLastError());
+
+#ifdef GGML_CUDA_DMMV_F16
+    if (src1_convert_f16) {
+        ggml_cuda_pool_free(src1_dfloat, ash);
     }
+#endif // GGML_CUDA_DMMV_F16
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
+    (void) src1;
+    (void) dst;
+    (void) src0_ddf_i;
+    (void) i02;
+    (void) i1;
 }
 
-static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+inline void ggml_cuda_op_mul_mat_cublas(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
+    const int64_t ne0 = dst->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that cuBLAS writes into
+    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
+    CUBLAS_CHECK(
+        cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                i01_diff, ne11, ne10,
+                &alpha, src0_ddf_i, ne00,
+                        src1_ddf_i, ne10,
+                &beta,  dst_ddf_i,  ldc));
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i02;
+    (void) i1;
+}
 
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+inline void ggml_cuda_op_rope(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-
-    size_t x_size, y_size, d_size;
-    half  * d_X =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
-    half  * d_Y =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-
-    bool src1_cont_rows = nb10 == sizeof(float);
-    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
-
-            half  * c_X = d_X + i * x_ne;
-            half  * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-
-            // copy src0 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
-
-            // convert src1 to fp16
-            // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
-            if (src1_cont_rows) {
-                if (src1_cont_cols) {
-                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
-                }
-                else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
-                    }
-                }
-            }
-            else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
-                        // very slow due to no inlining
-                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
-                    }
-                }
-            }
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
 
-            // copy src1 to device
-            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, CUDA_R_16F, ne00,
-                                c_Y, CUDA_R_16F, ne10,
-                        &beta,  c_D, CUDA_R_32F, ne01,
-                        CUBLAS_COMPUTE_32F_FAST_16F,
-                        CUBLAS_GEMM_DEFAULT));
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+    GGML_ASSERT(mode == 0);
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+
+    // compute
+    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_diag_mask_inf(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) src1->data)[0];
+
+    // compute
+    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_soft_max(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_scale(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const float scale = ((float *) src1->data)[0];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
 }
 
-static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                         ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
+    const int64_t nrows0 = ggml_nrows(src0);
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+    const bool use_src1 = src1 != nullptr;
+    const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
+    const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
+    const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
 
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
-    const ggml_type type = src0->type;
-    const bool mul_mat_vec = ne11 == 1;
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
-
-    size_t x_size, y_size, d_size, q_size;
-    float * d_X = nullptr;
-    if (!mul_mat_vec) {
-        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
-    }
-    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-    char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
-    dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
-    GGML_ASSERT(to_fp32_cuda != nullptr);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
-
-            float * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-            char  * c_Q = d_Q + i * q_sz;
-
-            // copy src0 to device if necessary
-            if (src0->backend == GGML_BACKEND_CPU) {
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            } else if (src0->backend == GGML_BACKEND_CUDA) {
-                c_Q = ((char *) src0->data) + i * q_sz;
+    GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+
+    // strides for iteration over dims 3 and 2
+    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
+    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t src0_stride = ne00 * ne01 * stride_mod;
+    const int64_t src1_stride = ne10 * ne11 * stride_mod;
+    const int64_t dst_stride = ne0 * ne1 * stride_mod;
+
+    const size_t src0_ts = ggml_type_size(src0->type);
+    const size_t src0_bs = ggml_blck_size(src0->type);
+
+    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *) dst->extra;
+
+    const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
+    const bool src0_is_contiguous = ggml_is_contiguous(src0);
+    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
+
+    const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
+    const bool src1_stays_on_host = use_src1 && (
+        dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
+
+    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+
+    // dd = data device
+    char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
+    float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
+    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    float *  dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+    // asq = actual size quantized, asf = actual size float
+    size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+
+    // if multiple GPUs are used they need to wait for the main GPU to finish
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
+    for (int id = 0; id < g_device_count; ++id) {
+        if (!split && id != g_main_device) {
+            continue;
+        }
+
+        const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+        const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+
+        int64_t row_low, row_high;
+        if (split) {
+            row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
+            row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
+        } else {
+            row_low = 0;
+            row_high = nrows0;
+        }
+        if (row_low == row_high) {
+            continue;
+        }
+
+        int64_t row_diff = row_high - row_low;
+
+        cudaSetDevice(id);
+
+        if (src0_on_device && src0_is_contiguous) {
+            if (src0_is_f32) {
+                src0_ddf[id] = (float *) src0_extra->data_device[id];
+            } else {
+                src0_ddq[id] = (char *) src0_extra->data_device[id];
+            }
+        } else {
+            if (src0_is_f32) {
+                src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
             } else {
-                GGML_ASSERT(false);
+                src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
             }
-            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
-                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+        }
 
-                // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+        if (src0_needs_f32 && !src0_is_f32) {
+            src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+        }
 
-                // wait for data
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+        if (use_src1 && !src1_stays_on_host) {
+            if (src1_on_device && src1_is_contiguous) {
+                src1_ddf[id] = (float *) src1_extra->data_device[id];
+            } else {
+                src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
+            }
+        }
+        if (dst_on_device) {
+            dst_ddf[id] = (float *) dst_extra->data_device[id];
+        } else {
+            size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
+            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
+        }
 
-                // compute
-                dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
-                CUDA_CHECK(cudaGetLastError());
+        const int64_t i03_max = flatten_rows ? 1 : ne03;
+        const int64_t i02_max = flatten_rows ? 1 : ne02;
+        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
 
-            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
-                float * c_X = d_X + i * x_ne;
+        for (int64_t i03 = 0; i03 < i03_max; i03++) {
+            const int64_t i13 = i03 % ne13;
+            for (int64_t i02 = 0; i02 < i02_max; i02++) {
+                const int64_t i12 = i02 % ne12;
 
-                // convert src0 to fp32 on device
-                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-                CUDA_CHECK(cudaGetLastError());
-                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+                const int64_t i0 = i03*ne02 + i02;
 
-                // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
+                const int64_t i0_offset_low = row_low/rows_per_iter;
+                const int64_t i0_offset_high = row_high/rows_per_iter;
 
-                // wait for conversion
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+                int64_t i01_low = 0;
+                int64_t i01_high = rows_per_iter;
+                if (split) {
+                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
+                        continue;
+                    }
+                    if (i0 == i0_offset_low) {
+                        i01_low = row_low % rows_per_iter;
+                    }
+                    if (i0 == i0_offset_high) {
+                        i01_high = row_high % rows_per_iter;
+                    }
+                }
 
-                // compute
-                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-                CUBLAS_CHECK(
-                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            &alpha, c_X, ne00,
-                                    c_Y, ne10,
-                            &beta,  c_D, ne01));
-            }
+                // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
+                // Removing the first assert or changing the order of the arguments causes the second assert to fail.
+                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
+                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
+                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
+                GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                const int64_t i01_diff = i01_high - i01_low;
+                if (i01_diff == 0) {
+                    continue;
+                }
+                const int64_t i11 = i13*ne12 + i12;
+
+                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+                // for split tensors the data begins at i0 == i0_offset_low
+                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
+                float * dst_ddf_i  =  dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+
+                // for split tensors the data pointer needs to be rounded down
+                // to the bin edge for i03, i02 bins beyond the first
+                if (i0 - i0_offset_low > 0) {
+                    GGML_ASSERT(!flatten_rows);
+                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
+                    src0_ddf_i -= (row_low % ne01)*ne00;
+                    dst_ddf_i  -= (row_low % ne0)*ne1;
+                }
+
+                // the main device memory buffer can be on VRAM scratch, with space for all partial results
+                // in that case an offset on dst_ddf_i is needed
+                if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
+                    dst_ddf_i += i01_low; // offset is 0 if no tensor split
+                }
+
+                // copy src0, src1 to device if necessary
+                if (use_src1 && !src1_stays_on_host) {
+                    if (src1->backend == GGML_BACKEND_CPU) {
+                        GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
+                        int64_t nrows1 = flatten_rows ? nrows0 : ne11;
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
+                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
+                        if (id != g_main_device) {
+                            GGML_ASSERT(!flatten_rows);
+                            float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
+                            src1_ddf_i_source += i11*src1_stride;
+                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
+                                                    cudaMemcpyDeviceToDevice, cudaStream_main));
+                        }
+                    } else if (src1_on_device && !src1_is_contiguous) {
+                        GGML_ASSERT(!split);
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                }
+
+                if (!src0_on_device || !src0_is_contiguous) {
+                    if (src0_is_f32) {
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    } else {
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    }
+                }
+
+                // convert src0 to f32 if it is necessary for the ggml_cuda_op
+                if (src0_needs_f32 && !src0_is_f32) {
+                    to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+
+                // do the computation
+                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+
+                // copy dst to host or other device if necessary
+                if (!dst_on_device) {
+                    void * dst_off_device;
+                    cudaMemcpyKind kind;
+                    if (dst->backend == GGML_BACKEND_CPU) {
+                        dst_off_device = dst->data;
+                        kind = cudaMemcpyDeviceToHost;
+                    } else if (dst->backend == GGML_BACKEND_GPU) {
+                        dst_off_device = dst_extra->data_device[g_main_device];
+                        kind = cudaMemcpyDeviceToDevice;
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                    if (split) {
+                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
+                        // dst is NOT transposed.
+                        // The outputs of cuBLAS matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
+                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
+                        for (int64_t j = 0; j < ne1; ++j) {
+                            float * dhf_dst_i = (float *) ((char *) dst_off_device + (j*ne0 + i01_low)*sizeof(float) + i02*nb2 + i03*nb3);
+                            CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i + j*i01_diff, i01_diff*sizeof(float), kind, cudaStream_main));
+                        }
+                    } else {
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
+                    }
+                }
+            }
         }
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    if (!mul_mat_vec) {
-        ggml_cuda_pool_free(d_X, x_size);
+    // wait until each device is finished, then free their buffers
+    for (int id = 0; id < g_device_count; ++id) {
+        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+            continue;
+        }
+
+        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(cudaDeviceSynchronize());
+
+        if (src0_asq[id] > 0) {
+            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
+        }
+        if (src0_asf[id] > 0) {
+            ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (dst_asf[id] > 0) {
+            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
+        }
     }
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-    ggml_cuda_pool_free(d_Q, q_size);
 }
 
-void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_mul_f32(src0, src1, dst);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+}
+
+void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
+}
+
+void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
+}
+
+void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -815,111 +2310,414 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
         return true;
     }
 
     return false;
 }
 
-bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
-    size_t src0_sz = ggml_nbytes(src0);
-    size_t src1_sz = ggml_nbytes(src1);
+void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
-    // mul_mat_q: src0 is converted to fp32 on device
-    size_t mul_mat_q_transfer = src0_sz + src1_sz;
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
 
-    // mul_mat_f16: src1 is converted to fp16 on cpu
-    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
 
-    // choose the smaller one to transfer to the device
-    // TODO: this is not always the best choice due to the overhead of converting to fp16
-    return mul_mat_f16_transfer < mul_mat_q_transfer;
+    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
 }
 
-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
-    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
+void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_mul_mat_f32(src0, src1, dst);
-    }
-    else if (src0->type == GGML_TYPE_F16) {
-        if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-            ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
-        }
-        else {
-            ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    const int row_stride_x = nb01 / sizeof(half);
+    const int channel_stride_x = nb02 / sizeof(half);
+
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+}
+
+void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
+        src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
+
+    if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
+    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+        ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    }else if (src0->type == GGML_TYPE_F32) {
+        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
+        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
+        } else {
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
         }
+    } else {
+        GGML_ASSERT(false);
     }
-    else if (ggml_is_quantized(src0->type)) {
-        ggml_cuda_mul_mat_q_f32(src0, src1, dst);
-    }
-    else {
+}
+
+void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
+}
+
+void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
+
+    GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t nb00 = src0->nb[0];
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(src1->ne[3] == 1);
+
+    const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+
+    const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+
+    char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+    char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+    } else {
         GGML_ASSERT(false);
     }
+
+    (void) dst;
+}
+
+void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
+}
+
+void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
+}
+
+void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
+}
+
+void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    (void) src0;
+    (void) src1;
+    (void) dst;
+}
+
+void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
+    int nrows = ggml_nrows(tensor);
+    const size_t nb1 = tensor->nb[1];
+    ggml_backend backend = tensor->backend;
+    struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));
+
+    for (int id = 0; id < g_device_count; ++id) {
+        if (backend == GGML_BACKEND_GPU && id != g_main_device) {
+            continue;
+        }
+
+        cudaSetDevice(id);
+
+        int row_low, row_high;
+        if (backend == GGML_BACKEND_GPU) {
+            row_low = 0;
+            row_high = nrows;
+        } else if (backend == GGML_BACKEND_GPU_SPLIT) {
+            row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
+            row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
+        } else {
+            GGML_ASSERT(false);
+        }
+        if (row_low == row_high) {
+            continue;
+        }
+
+        int64_t nrows_split = row_high - row_low;
+
+        const size_t offset_split = row_low*nb1;
+        const size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+        void * buf;
+        CUDA_CHECK(cudaMalloc(&buf, size));
+        void * buf_host = (char*)data + offset_split;
+
+        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+
+        extra->data_device[id] = buf;
+    }
+
+    tensor->extra = extra;
 }
 
-size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+void ggml_cuda_free_data(struct ggml_tensor * tensor) {
+    if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
+        return;
     }
-    else {
-        return 0;
+
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+    for (int id = 0; id < g_device_count; ++id) {
+        if (extra->data_device[id] == nullptr) {
+            continue;
+        }
+
+        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(cudaFree(extra->data_device[id]));
     }
+
+    delete extra;
 }
 
-void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
-    const int64_t ne0 = tensor->ne[0];
-    const int64_t ne1 = tensor->ne[1];
-    const int64_t ne2 = tensor->ne[2];
-    const int64_t ne3 = tensor->ne[3];
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+    if (scratch && g_scratch_size == 0) {
+        return;
+    }
+
+    // recursively assign CUDA buffers until a compute tensor is found
+    if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
+        const ggml_op src0_op = tensor->src0->op;
+        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+        }
+    }
+    if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+    }
 
-    const ggml_type type = tensor->type;
-    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+    tensor->backend = GGML_BACKEND_GPU;
+    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
 
-    size_t q_size;
-    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+    const size_t size = ggml_nbytes(tensor);
 
-    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + offset;
+    } else if (tensor->op == GGML_OP_CPY) {
+        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
+        void * src1_ddv = src1_extra->data_device[g_main_device];
+        extra->data_device[g_main_device] = src1_ddv;
+    } else if (scratch) {
+        GGML_ASSERT(size <= g_scratch_size);
+        if (g_scratch_offset + size > g_scratch_size) {
+            g_scratch_offset = 0;
+        }
 
-    // copy tensor to device
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            int i = i3*ne2 + i2;
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        char * data = (char *) g_scratch_buffer;
+        if (data == nullptr) {
+            CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
+            g_scratch_buffer = data;
         }
+        extra->data_device[g_main_device] = data + g_scratch_offset;
+
+        g_scratch_offset += size;
+
+        GGML_ASSERT(g_scratch_offset <= g_scratch_size);
+    } else { // allocate new buffers outside of scratch
+        void * data;
+        CUDA_CHECK(cudaMalloc(&data, size));
+        CUDA_CHECK(cudaMemset(data, 0, size));
+        extra->data_device[g_main_device] = data;
     }
 
-    tensor->data = dst;
-    tensor->backend = GGML_BACKEND_CUDA;
+    tensor->extra = extra;
 }
 
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
-    FILE * fp = fopen(fname, "rb");
+void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true);
+}
 
-    const size_t size = ggml_nbytes(tensor);
+void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false);
+}
 
-    void * buf;
-    CUDA_CHECK(cudaMalloc(&buf, size));
-    void * buf_host = malloc(size);
+void ggml_cuda_set_main_device(int main_device) {
+    if (main_device >= g_device_count) {
+        fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
+                main_device, g_device_count, g_main_device);
+        return;
+    }
+    g_main_device = main_device;
+    if (g_device_count > 1) {
+        cudaDeviceProp prop;
+        CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
+        fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
+    }
+}
 
-#ifdef _WIN32
-    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
-#else
-    int ret = fseek(fp, (long) offset, SEEK_SET);
-#endif
-    GGML_ASSERT(ret == 0); // same
+void ggml_cuda_set_scratch_size(size_t scratch_size) {
+    g_scratch_size = scratch_size;
+}
 
-    size_t ret2 = fread(buf_host, size, 1, fp);
-    if (ret2 != 1) {
-        fprintf(stderr, "unexpectedly reached end of file");
-        exit(1);
+void ggml_cuda_free_scratch() {
+    if (g_scratch_buffer == nullptr) {
+        return;
     }
 
-    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
-    cudaDeviceSynchronize();
+    CUDA_CHECK(cudaFree(g_scratch_buffer));
+    g_scratch_buffer = nullptr;
+}
+
+bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
+    ggml_cuda_func_t func;
+    const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
+        || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+        || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
+
+    switch (tensor->op) {
+        case GGML_OP_ADD:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_add;
+            break;
+        case GGML_OP_MUL:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_mul;
+            break;
+        case GGML_OP_SILU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_silu;
+            break;
+        case GGML_OP_RMS_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_rms_norm;
+            break;
+        case GGML_OP_MUL_MAT:
+            if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
+                return false;
+            }
+            func = ggml_cuda_mul_mat;
+            break;
+        case GGML_OP_SCALE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_scale;
+            break;
+        case GGML_OP_CPY:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_cpy;
+            break;
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_nop;
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_diag_mask_inf;
+            break;
+        case GGML_OP_SOFT_MAX:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_soft_max;
+            break;
+        case GGML_OP_ROPE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_rope;
+            break;
+        default:
+            return false;
+    }
 
-    tensor->data = buf;
-    free(buf_host);
-    fclose(fp);
+    if (params->ith != 0) {
+        return true;
+    }
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return true;
+    }
+    func(tensor->src0, tensor->src1, tensor);
+    return true;
 }
index 6a04dde6c37a9d3179747575ecb2dee455a29720..d32b4484267ab5a3f9028db07330302432c1fd93 100644 (file)
@@ -1,10 +1,19 @@
+#pragma once
+
 #include "ggml.h"
 
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
+#define GGML_CUDA_MAX_DEVICES       16
+
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+};
+
 void   ggml_init_cublas(void);
+void   ggml_cuda_set_tensor_split(const float * tensor_split);
 
 void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
@@ -15,8 +24,15 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
-void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
+void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+
+void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void   ggml_cuda_set_main_device(int main_device);
+void   ggml_cuda_set_scratch_size(size_t scratch_size);
+void   ggml_cuda_free_scratch(void);
+bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
 }
index 7bcc603ef8432f90a48e279168ba577f2b8468fd..a92b445c9d7660da6ed5dfcbc4cce9ae7a5b9827 100644 (file)
@@ -1,23 +1,24 @@
 #pragma once
 
+#include "ggml.h"
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
 void ggml_cl_init(void);
 
-enum ggml_blas_order {
-    GGML_BLAS_ORDER_ROW_MAJOR = 101,
-    GGML_BLAS_ORDER_COLUMN_MAJOR = 102,
-};
+void   ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+bool   ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+void   ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
+
+void * ggml_cl_host_malloc(size_t size);
+void   ggml_cl_host_free(void * ptr);
 
-enum ggml_blas_op {
-    GGML_BLAS_OP_N = 111,
-    GGML_BLAS_OP_T = 112,
-    GGML_BLAS_OP_C = 113,
-};
+void ggml_cl_free_data(const struct ggml_tensor* tensor);
 
-void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype);
+void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
 }
diff --git a/ggml.c b/ggml.c
index 77a3d89f748e0bb67d4d82218f67bc0bc172e706..955f335cd18275174f2e3b338db013dc70aa7c2d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1,8 +1,12 @@
-// Defines CLOCK_MONOTONIC on Linux
-#define _GNU_SOURCE
+#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 
 #include "ggml.h"
 
+#ifdef GGML_USE_K_QUANTS
+#include "k_quants.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
 #include <stdio.h>
 #include <float.h>
 #include <limits.h>
+#include <stdarg.h>
+
+#ifdef GGML_USE_METAL
+#include <unistd.h>
+#endif
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
 #if defined(_WIN32)
 
 #include <windows.h>
@@ -98,6 +113,7 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
@@ -115,15 +131,58 @@ typedef void* thread_ret_t;
     #define GGML_MEM_ALIGN 16
 #endif
 
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+//
+// end of logging block
+//
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #define GGML_ALIGNED_MALLOC(size)  _aligned_malloc(size, GGML_MEM_ALIGN)
 #define GGML_ALIGNED_FREE(ptr)     _aligned_free(ptr)
 #else
 inline static void* ggml_aligned_malloc(size_t size) {
     void* aligned_memory = NULL;
+#ifdef GGML_USE_METAL
+    int result = posix_memalign(&aligned_memory, getpagesize(), size);
+#else
     int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+#endif
     if (result != 0) {
         // Handle allocation failure
+        const char *error_desc = "unknown allocation error";
+        switch (result) {
+            case EINVAL:
+                error_desc = "invalid alignment value";
+                break;
+            case ENOMEM:
+                error_desc = "insufficient memory";
+                break;
+        }
+        GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n",
+            __func__, error_desc, size/(1024.0*1024.0));
         return NULL;
     }
     return aligned_memory;
@@ -186,10 +245,12 @@ typedef double ggml_float;
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
+#if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
+#endif
 
 #ifdef __F16C__
 
@@ -320,6 +381,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+
 // precomputed silu table for f16 (128 KB)
 static ggml_fp16_t table_silu_f16[1 << 16];
 
@@ -401,21 +465,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
 //
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
-static int64_t timer_freq;
+static int64_t timer_freq, timer_start;
 void ggml_time_init(void) {
-    LARGE_INTEGER frequency;
-    QueryPerformanceFrequency(&frequency);
-    timer_freq = frequency.QuadPart;
+    LARGE_INTEGER t;
+    QueryPerformanceFrequency(&t);
+    timer_freq = t.QuadPart;
+
+    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
+    // and the uptime is high enough.
+    // We subtract the program start time to reduce the likelihood of that happening.
+    QueryPerformanceCounter(&t);
+    timer_start = t.QuadPart;
 }
 int64_t ggml_time_ms(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
 }
 int64_t ggml_time_us(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
 }
 #else
 void ggml_time_init(void) {}
@@ -472,6 +542,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
@@ -531,7 +603,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     return _mm256_and_si256(lowMask, bytes);
 }
@@ -604,7 +676,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     bytesh = _mm_or_si128(bytesh, bit_mask);
     bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
     bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
-    return _mm256_set_m128i(bytesh, bytesl);
+    return MM256_SET_M128I(bytesh, bytesl);
 }
 
 // Unpack 32 4-bit fields into 32 bytes
@@ -617,7 +689,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     const __m128i lowMask = _mm_set1_epi8(0xF);
     tmpl = _mm_and_si128(lowMask, tmpl);
     tmph = _mm_and_si128(lowMask, tmph);
-    return _mm256_set_m128i(tmph, tmpl);
+    return MM256_SET_M128I(tmph, tmpl);
 }
 
 // add int16_t pairwise and return as float vector
@@ -625,7 +697,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     const __m128i ones = _mm_set1_epi16(1);
     const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
     const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
-    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
@@ -1563,6 +1635,48 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = NULL,   // TODO
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q2_K,
+        .quantize_row_q           = quantize_row_q2_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q3_K,
+        .quantize_row_q           = quantize_row_q3_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
+        .quantize_row_q           = quantize_row_q4_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_K,
+        .quantize_row_q           = quantize_row_q5_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q6_K,
+        .quantize_row_q           = quantize_row_q6_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+#endif
 };
 
 // For internal test use
@@ -1607,14 +1721,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = vaddq_f32(x[2*i], x[2*i+1]);  \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = vaddq_f32(x[4*i], x[4*i+2]);  \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = vaddq_f32(x[8*i], x[8*i+4]);  \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
     }                                          \
     res = GGML_F32x4_REDUCE_ONE(x[0]);         \
 }
@@ -1645,14 +1762,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
     #define GGML_F16x8_MUL          vmulq_f16
     #define GGML_F16x8_REDUCE(res, x)                             \
     {                                                             \
-        for (int i = 0; i < GGML_F16_ARR/2; ++i) {                \
-            x[2*i] = vaddq_f16(x[2*i], x[2*i+1]);                 \
+        int offset = GGML_F16_ARR >> 1;                           \
+        for (int i = 0; i < offset; ++i) {                        \
+            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
         }                                                         \
-        for (int i = 0; i < GGML_F16_ARR/4; ++i) {                \
-            x[4*i] = vaddq_f16(x[4*i], x[4*i+2]);                 \
+        offset >>= 1;                                             \
+        for (int i = 0; i < offset; ++i) {                        \
+            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
         }                                                         \
-        for (int i = 0; i < GGML_F16_ARR/8; ++i) {                \
-            x[8*i] = vaddq_f16(x[8*i], x[8*i+4]);                 \
+        offset >>= 1;                                             \
+        for (int i = 0; i < offset; ++i) {                        \
+            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
         }                                                         \
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
@@ -1719,14 +1839,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F32x8_MUL     _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x)                                 \
 {                                                                 \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) {                    \
-        x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]);                 \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
     }                                                             \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) {                    \
-        x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]);                 \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
     }                                                             \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) {                    \
-        x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]);                 \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
     }                                                             \
     const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
                                  _mm256_extractf128_ps(x[0], 1)); \
@@ -1816,14 +1939,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32x4_MUL          vec_mul
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = vec_add(x[2*i], x[2*i+1]);    \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = vec_add(x[4*i], x[4*i+2]);    \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = vec_add(x[8*i], x[8*i+4]);    \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
     }                                          \
     res = vec_extract(x[0], 0) +               \
           vec_extract(x[0], 1) +               \
@@ -1879,14 +2005,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32x4_MUL          wasm_f32x4_mul
 #define GGML_F32x4_REDUCE(res, x)                  \
 {                                                  \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) {     \
-        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) {     \
-        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) {     \
-        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
     res = wasm_f32x4_extract_lane(x[0], 0) +       \
           wasm_f32x4_extract_lane(x[0], 1) +       \
@@ -1941,14 +2070,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F16x4_MUL         wasm_f32x4_mul
 #define GGML_F16x4_REDUCE(res, x)                  \
 {                                                  \
-    for (int i = 0; i < GGML_F16_ARR/2; ++i) {     \
-        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F16_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F16_ARR/4; ++i) {     \
-        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F16_ARR/8; ++i) {     \
-        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
     res = wasm_f32x4_extract_lane(x[0], 0) +       \
           wasm_f32x4_extract_lane(x[0], 1) +       \
@@ -1990,14 +2122,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F32x4_MUL     _mm_mul_ps
 #define GGML_F32x4_REDUCE(res, x)                                 \
 {                                                                 \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) {                    \
-        x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]);                    \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
     }                                                             \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) {                    \
-        x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]);                    \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
     }                                                             \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) {                    \
-        x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]);                    \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
     }                                                             \
     const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
     res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0));                     \
@@ -2288,7 +2423,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
         // Apply the scale, and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2764,7 +2899,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3020,7 +3155,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3286,6 +3421,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A    = 0.044715f;
+static const float GELU_QUICK_COEF    = -1.702f;
 static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
@@ -3316,6 +3452,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = table_gelu_quick_f16[i16[i]];
+//    }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
@@ -3405,30 +3569,6 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
     *s = 1.f/(*s);
 }
 
-//
-// logging
-//
-
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
 //
 // data types
 //
@@ -3442,11 +3582,19 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_Q8_1] = QK8_1,
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = QK_K,
+    [GGML_TYPE_Q3_K] = QK_K,
+    [GGML_TYPE_Q4_K] = QK_K,
+    [GGML_TYPE_Q5_K] = QK_K,
+    [GGML_TYPE_Q6_K] = QK_K,
+    [GGML_TYPE_Q8_K] = QK_K,
+#endif
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
@@ -3457,11 +3605,19 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
+    [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
+    [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
+    [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
+    [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
+    [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
+#endif
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3473,11 +3629,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_Q8_1] = "q8_1",
+    [GGML_TYPE_Q2_K] = "q2_K",
+    [GGML_TYPE_Q3_K] = "q3_K",
+    [GGML_TYPE_Q4_K] = "q4_K",
+    [GGML_TYPE_Q5_K] = "q5_K",
+    [GGML_TYPE_Q6_K] = "q6_K",
+    [GGML_TYPE_Q8_K] = "q8_K",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = false,
@@ -3488,13 +3650,19 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_Q2_K] = true,
+    [GGML_TYPE_Q3_K] = true,
+    [GGML_TYPE_Q4_K] = true,
+    [GGML_TYPE_Q5_K] = true,
+    [GGML_TYPE_Q6_K] = true,
+    [GGML_TYPE_Q8_K] = true,
     [GGML_TYPE_I8]   = false,
     [GGML_TYPE_I16]  = false,
     [GGML_TYPE_I32]  = false,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
 
-static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
 
     "DUP",
@@ -3511,12 +3679,14 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
     "STEP",
     "RELU",
     "GELU",
+    "GELU_QUICK",
     "SILU",
     "SILU_BACK",
     "NORM",
@@ -3524,6 +3694,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3539,22 +3710,33 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D_1S",
-    "CONV_1D_2S",
+    "CONV_1D_S1_PH",
+    "CONV_1D_S2_PH",
+    "CONV_2D_SK_P0",
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
+    "WIN_PART",
+    "WIN_UNPART",
 
     "MAP_UNARY",
     "MAP_BINARY",
-};
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+    "MAP_CUSTOM1",
+    "MAP_CUSTOM2",
+    "MAP_CUSTOM3",
+
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+};
 
+static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3573,18 +3755,21 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "gelu_quick(x)",
     "silu(x)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3601,21 +3786,33 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d_1s(x)",
-    "conv_1d_2s(x)",
+    "conv_1d_s1_ph(x)",
+    "conv_1d_s2_ph(x)",
+    "conv_2d_sk_p0(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
+    "win_part(x)",
+    "win_unpart(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "custom(x)",
+    "custom(x,y)",
+    "custom(x,y,z)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3629,6 +3826,7 @@ struct ggml_context {
     void * mem_buffer;
     bool   mem_buffer_owned;
     bool   no_alloc;
+    bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
 
     int    n_objects;
 
@@ -3645,26 +3843,6 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
-//
-// compute types
-//
-
-enum ggml_task_type {
-    GGML_TASK_INIT = 0,
-    GGML_TASK_COMPUTE,
-    GGML_TASK_FINALIZE,
-};
-
-struct ggml_compute_params {
-    enum ggml_task_type type;
-
-    int ith, nth;
-
-    // work buffer for all threads
-    size_t wsize;
-    void * wdata;
-};
-
 //
 // ggml state
 //
@@ -3721,7 +3899,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3730,7 +3908,20 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+    // this should handle cases where the tensor is not contiguous in memory
+    // probaby just:
+    //
+    //     return tensor->ne[3]*tensor->nb[3]
+    //
+    // is enough, but just in case, adding the second part
+
+    return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
+}
+
+size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
 }
 
 int ggml_blck_size(enum ggml_type type) {
@@ -3749,6 +3940,9 @@ const char * ggml_type_name(enum ggml_type type) {
     return GGML_TYPE_NAME[type];
 }
 
+const char * ggml_op_name(enum ggml_op op) {
+    return GGML_OP_NAME[op];
+}
 
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
@@ -3781,6 +3975,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1])  &&
+        (t0->ne[2] == t1->ne[2])  &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3796,6 +3999,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
+        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
+        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
+        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;
+        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -3805,11 +4013,15 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
     return wtype;
 }
 
-static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+size_t ggml_tensor_overhead(void) {
+    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
+}
+
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3819,6 +4031,12 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3885,7 +4103,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize time system (required on Windows)
         ggml_time_init();
 
-        // initialize GELU, SILU and EXP F32 tables
+        // initialize GELU, Quick GELU, SILU and EXP F32 tables
         {
             const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -3895,13 +4113,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 memcpy(&ii, &ui, sizeof(ii));
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
                 table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
                 table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-            GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
         // initialize g_state
@@ -3958,6 +4177,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
         /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
         /*.no_alloc           =*/ params.no_alloc,
+        /*.no_alloc_save      =*/ params.no_alloc,
         /*.n_objects          =*/ 0,
         /*.objects_begin      =*/ NULL,
         /*.objects_end        =*/ NULL,
@@ -4017,17 +4237,56 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
+    ctx->no_alloc = no_alloc;
+}
+
+void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
+    return ctx->mem_buffer;
+}
+
+size_t ggml_get_mem_size(const struct ggml_context * ctx) {
+    return ctx->mem_size;
+}
+
+size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
+    size_t max_size = 0;
+
+    struct ggml_object * obj = ctx->objects_begin;
+
+    while (obj != NULL) {
+        struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
+
+        const size_t size = ggml_nbytes(tensor);
+
+        if (max_size < size) {
+            max_size = size;
+        }
+
+        obj = obj->next;
+    }
+
+    return max_size;
+}
+
 // IMPORTANT:
 // when creating "opt" tensors, always save and load the scratch buffer
 // this is an error prone process, but it is necessary to support inplace
 // operators when using scratch buffers
 // TODO: implement a better way
 void ggml_scratch_save(struct ggml_context * ctx) {
+    // this is needed to allow opt tensors to store their data
+    // TODO: again, need to find a better way
+    ctx->no_alloc_save = ctx->no_alloc;
+    ctx->no_alloc      = false;
+
     ctx->scratch_save = ctx->scratch;
     ctx->scratch.data = NULL;
 }
 
 void ggml_scratch_load(struct ggml_context * ctx) {
+    ctx->no_alloc = ctx->no_alloc_save;
+
     ctx->scratch = ctx->scratch_save;
 }
 
@@ -4061,7 +4320,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
     if (ctx->scratch.data == NULL || data != NULL) {
-        size_needed += sizeof(struct ggml_tensor);
+        size_needed += GGML_TENSOR_SIZE;
 
         if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
             GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
@@ -4077,14 +4336,15 @@ struct ggml_tensor * ggml_new_tensor_impl(
         };
     } else {
         if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
-            GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
+            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+                    __func__, ctx->scratch.offs + size_needed, ctx->scratch.size);
             assert(false);
             return NULL;
         }
 
-        if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
+        if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) {
             GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
-                    __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
+                    __func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size);
             assert(false);
             return NULL;
         }
@@ -4093,7 +4353,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
         *obj_new = (struct ggml_object) {
             .offs = cur_end + GGML_OBJECT_SIZE,
-            .size = sizeof(struct ggml_tensor),
+            .size = GGML_TENSOR_SIZE,
             .next = NULL,
         };
 
@@ -4135,6 +4395,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
+        /*.extra        =*/ NULL,
         /*.pad          =*/ { 0 },
     };
 
@@ -4491,15 +4752,25 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
 
-void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
     strncpy(tensor->name, name, sizeof(tensor->name));
     tensor->name[sizeof(tensor->name) - 1] = '\0';
+    return tensor;
+}
+
+struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
+    va_end(args);
+    return tensor;
 }
 
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+    ggml_format_name(result, "%s (view)", src->name);
 
     result->nb[0] = src->nb[0];
     result->nb[1] = src->nb[1];
@@ -4509,6 +4780,23 @@ struct ggml_tensor * ggml_view_tensor(
     return result;
 }
 
+struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
+        if (strcmp(cur->name, name) == 0) {
+            return cur;
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 // ggml_dup
@@ -4556,7 +4844,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4596,7 +4884,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5022,6 +5310,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
@@ -5227,6 +5543,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU_QUICK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_quick_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_quick_impl(ctx, a, true);
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -5399,6 +5749,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
@@ -5411,7 +5787,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5454,7 +5830,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5556,6 +5932,11 @@ struct ggml_tensor * ggml_cpy_impl(
 
     // make a view of the destination
     struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+    if (strlen(b->name) > 0) {
+        ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
+    } else {
+        ggml_format_name(result, "%s (copy)", a->name);
+    }
 
     result->op   = GGML_OP_CPY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5592,6 +5973,7 @@ struct ggml_tensor * ggml_cont_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_format_name(result, "%s (cont)", a->name);
 
     result->op   = GGML_OP_CONT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5635,6 +6017,7 @@ struct ggml_tensor * ggml_reshape(
     }
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
+    ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5659,6 +6042,7 @@ struct ggml_tensor * ggml_reshape_1d(
 
     const int64_t ne[1] = { ne0 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data);
+    ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5684,6 +6068,7 @@ struct ggml_tensor * ggml_reshape_2d(
 
     const int64_t ne[2] = { ne0, ne1 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
+    ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5710,6 +6095,7 @@ struct ggml_tensor * ggml_reshape_3d(
 
     const int64_t ne[3] = { ne0, ne1, ne2 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
+    ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5738,6 +6124,7 @@ struct ggml_tensor * ggml_reshape_4d(
 
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data);
+    ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5762,15 +6149,21 @@ struct ggml_tensor * ggml_view_1d(
     }
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
+    ggml_format_name(result, "%s (view)", a->name);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ggml_set_name(offs, "offset");
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
 
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5794,8 +6187,17 @@ struct ggml_tensor * ggml_view_2d(
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
+    ggml_format_name(result, "%s (view)", a->name);
 
-    result->nb[1] = nb1;
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ggml_set_name(offs, "offset");
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
+    result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
     result->nb[3] = result->nb[2];
 
@@ -5803,10 +6205,7 @@ struct ggml_tensor * ggml_view_2d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5832,6 +6231,15 @@ struct ggml_tensor * ggml_view_3d(
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
+    ggml_format_name(result, "%s (view)", a->name);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ggml_set_name(offs, "offset");
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -5841,10 +6249,7 @@ struct ggml_tensor * ggml_view_3d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5872,6 +6277,15 @@ struct ggml_tensor * ggml_view_4d(
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
+    ggml_format_name(result, "%s (view)", a->name);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ggml_set_name(offs, "offset");
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -5881,10 +6295,7 @@ struct ggml_tensor * ggml_view_4d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5917,6 +6328,7 @@ struct ggml_tensor * ggml_permute(
     }
 
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    ggml_format_name(result, "%s (permuted)", a->name);
 
     int ne[GGML_MAX_DIMS];
     int nb[GGML_MAX_DIMS];
@@ -5947,10 +6359,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-        result->padding[0] = axis0;
-        result->padding[1] = axis1;
-        result->padding[2] = axis2;
-        result->padding[3] = axis3;
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
@@ -5968,6 +6388,7 @@ struct ggml_tensor * ggml_transpose(
     }
 
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    ggml_format_name(result, "%s (transposed)", a->name);
 
     result->ne[0] = a->ne[1];
     result->ne[1] = a->ne[0];
@@ -6190,6 +6611,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
@@ -6202,7 +6661,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6256,8 +6715,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6303,7 +6761,7 @@ struct ggml_tensor * ggml_alibi(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_head;
@@ -6339,7 +6797,7 @@ struct ggml_tensor * ggml_clamp(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
 
     ((float *) b->data)[0] = min;
     ((float *) b->data)[1] = max;
@@ -6354,9 +6812,9 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d_1s
+// ggml_conv_1d_s1_ph
 
-struct ggml_tensor * ggml_conv_1d_1s(
+struct ggml_tensor * ggml_conv_1d_s1_ph(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
@@ -6373,7 +6831,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
     const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    result->op   = GGML_OP_CONV_1D_1S;
+    result->op   = GGML_OP_CONV_1D_S1_PH;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
@@ -6381,9 +6839,9 @@ struct ggml_tensor * ggml_conv_1d_1s(
     return result;
 }
 
-// ggml_conv_1d_2s
+// ggml_conv_1d_s2_ph
 
-struct ggml_tensor * ggml_conv_1d_2s(
+struct ggml_tensor * ggml_conv_1d_s2_ph(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
@@ -6400,7 +6858,35 @@ struct ggml_tensor * ggml_conv_1d_2s(
     const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    result->op   = GGML_OP_CONV_1D_2S;
+    result->op   = GGML_OP_CONV_1D_S2_PH;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_conv_2d_sk_p0
+
+struct ggml_tensor * ggml_conv_2d_sk_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(b->ne[0] %  a->ne[0] == 0);
+    GGML_ASSERT(b->ne[1] %  a->ne[1] == 0);
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_CONV_2D_SK_P0;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
@@ -6422,7 +6908,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6454,7 +6939,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6472,6 +6956,154 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * d,
+        bool                  masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t   D = q->ne[0];
+    const int64_t   N = q->ne[1];
+    const int64_t   M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
+    }
+
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
+// ggml_win_part
+
+struct ggml_tensor * ggml_win_part(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w) {
+    GGML_ASSERT(a->ne[3] == 1);
+    GGML_ASSERT(a->type  == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // padding
+    const int px = (w - a->ne[1]%w)%w;
+    const int py = (w - a->ne[2]%w)%w;
+
+    const int npx = (px + a->ne[1])/w;
+    const int npy = (py + a->ne[2])/w;
+    const int np  = npx*npy;
+
+    const int64_t ne[4] = { a->ne[0], w, w, np, };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+    ((int32_t *) b->data)[0] = npx;
+    ((int32_t *) b->data)[1] = npy;
+    ((int32_t *) b->data)[2] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_PART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
+// ggml_win_unpart
+
+struct ggml_tensor * ggml_win_unpart(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w0,
+        int                   h0,
+        int                   w) {
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ((int32_t *) b->data)[0] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_UNPART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
 // ggml_map_unary
 
 struct ggml_tensor * ggml_map_unary_impl_f32(
@@ -6485,9 +7117,14 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
         is_node = true;
     }
 
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
     *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_load(ctx);
 
     result->op = GGML_OP_MAP_UNARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6527,9 +7164,14 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
         is_node = true;
     }
 
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
     *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_load(ctx);
 
     result->op = GGML_OP_MAP_BINARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6556,72 +7198,260 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
-////////////////////////////////////////////////////////////////////////////////
+// ggml_map_custom1
 
-void ggml_set_param(
-        struct ggml_context * ctx,
-        struct ggml_tensor * tensor) {
-    tensor->is_param = true;
+struct ggml_tensor * ggml_map_custom1_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun,
+        bool   inplace) {
+    bool is_node = false;
 
-    GGML_ASSERT(tensor->grad == NULL);
-    tensor->grad = ggml_dup_tensor(ctx, tensor);
-}
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
 
-// ggml_compute_forward_dup
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-static void ggml_compute_forward_dup_same_cont(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-    GGML_ASSERT(src0->type == dst->type);
+    ggml_scratch_save(ctx);
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb0 = dst->nb[0];
+    ggml_scratch_load(ctx);
 
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
+    result->op = GGML_OP_MAP_CUSTOM1;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
 
-    // parallelize by elements
-    const int ne = ggml_nelements(dst);
-    const int dr = (ne + nth - 1) / nth;
-    const int ie0 = dr * ith;
-    const int ie1 = MIN(ie0 + dr, ne);
+    return result;
+}
 
-    if (ie0 < ie1) {
-        memcpy(
-            ((char *)  dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
-    }
+struct ggml_tensor * ggml_map_custom1_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun) {
+    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
+}
 
+struct ggml_tensor * ggml_map_custom1_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun) {
+    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
 }
-static void ggml_compute_forward_dup_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
+// ggml_map_custom2
+
+struct ggml_tensor * ggml_map_custom2_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun,
+        bool   inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
     }
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
+    ggml_scratch_save(ctx);
 
-    const size_t nb00 = src0->nb[0];
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_MAP_CUSTOM2;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom2_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun) {
+    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom2_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun) {
+    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
+}
+
+// ggml_map_custom3
+
+struct ggml_tensor * ggml_map_custom3_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun,
+        bool   inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad || c->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_MAP_CUSTOM3;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+    result->opt[1] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom3_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun) {
+    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom3_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun) {
+    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
+}
+
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        struct ggml_tensor          * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(
+        struct ggml_context * ctx,
+        struct ggml_tensor * tensor) {
+    tensor->is_param = true;
+
+    GGML_ASSERT(tensor->grad == NULL);
+    tensor->grad = ggml_dup_tensor(ctx, tensor);
+}
+
+// ggml_compute_forward_dup
+
+static void ggml_compute_forward_dup_same_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb0 = dst->nb[0];
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by elements
+    const int ne = ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = MIN(ie0 + dr, ne);
+
+    if (ie0 < ie1) {
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+    }
+
+}
+static void ggml_compute_forward_dup_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
@@ -7505,7 +8335,7 @@ static void ggml_compute_forward_add_q_f32(
 
         void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
         float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
-        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
 
         assert(ne00 % 32 == 0);
 
@@ -7545,6 +8375,11 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7848,6 +8683,11 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -7970,6 +8810,11 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -8088,10 +8933,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-#ifdef GGML_USE_CUBLAS
-    if (src1->backend == GGML_BACKEND_CUDA) {
+#ifdef GGML_USE_CLBLAST
+    if (src1->backend == GGML_BACKEND_GPU) {
         if (ith == 0) {
-            ggml_cuda_mul(src0, src1, dst);
+            ggml_cl_mul(src0, src1, dst);
         }
         return;
     }
@@ -8691,6 +9536,99 @@ static void ggml_compute_forward_repeat(
     }
 }
 
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for         (int k3 = 0; k3 < ne3; k3++) {
+            for     (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3; i3++) {
+        for                     (int k3 = 0; k3 < ne3; k3++) {
+            for                 (int i2 = 0; i2 < nr2; i2++) {
+                for             (int k2 = 0; k2 < ne2; k2++) {
+                    for         (int i1 = 0; i1 < nr1; i1++) {
+                        for     (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
@@ -8958,13 +9896,11 @@ static void ggml_compute_forward_gelu(
                 GGML_ASSERT(false);
             } break;
     }
-
-    //printf("XXXXXXXX gelu\n");
 }
 
-// ggml_compute_forward_silu
+// ggml_compute_forward_gelu_quick
 
-static void ggml_compute_forward_silu_f32(
+static void ggml_compute_forward_gelu_quick_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -8990,7 +9926,7 @@ static void ggml_compute_forward_silu_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_silu_f32(nc,
+        ggml_vec_gelu_quick_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
                 (float *) ((char *) src0->data + i1*(src0->nb[1])));
 
@@ -9005,14 +9941,14 @@ static void ggml_compute_forward_silu_f32(
     }
 }
 
-static void ggml_compute_forward_silu(
+static void ggml_compute_forward_gelu_quick(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_silu_f32(params, src0, dst);
+                ggml_compute_forward_gelu_quick_f32(params, src0, dst);
             } break;
         default:
             {
@@ -9021,19 +9957,15 @@ static void ggml_compute_forward_silu(
     }
 }
 
+// ggml_compute_forward_silu
 
-// ggml_compute_forward_silu_back
-
-static void ggml_compute_forward_silu_back_f32(
+static void ggml_compute_forward_silu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * grad,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(grad));
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -9053,10 +9985,9 @@ static void ggml_compute_forward_silu_back_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_silu_backward_f32(nc,
+        ggml_vec_silu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
-                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -9069,15 +10000,14 @@ static void ggml_compute_forward_silu_back_f32(
     }
 }
 
-static void ggml_compute_forward_silu_back(
+static void ggml_compute_forward_silu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * grad,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_silu_back_f32(params, src0, grad, dst);
+                ggml_compute_forward_silu_f32(params, src0, dst);
             } break;
         default:
             {
@@ -9086,29 +10016,94 @@ static void ggml_compute_forward_silu_back(
     }
 }
 
-// ggml_compute_forward_norm
 
-static void ggml_compute_forward_norm_f32(
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * grad,
         struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
 
-    const size_t nb01 = src0->nb[1];
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * grad,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_back_f32(params, src0, grad, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
@@ -9206,7 +10201,7 @@ static void ggml_compute_forward_rms_norm_f32(
                     sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                float mean = sum/ne00;
+                const float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -9431,7 +10426,7 @@ static void ggml_compute_forward_rms_norm_back(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -9472,7 +10467,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -9529,16 +10524,16 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+#if defined(GGML_USE_CLBLAST)
+    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
         return;
     }
 #endif
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -9558,21 +10553,11 @@ static void ggml_compute_forward_mul_mat_f32(
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-#if defined(GGML_USE_CLBLAST)
-                // zT = y * xT
-                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne10,
-                        0.0f,    d, ne01,
-                        GGML_TYPE_F32);
-#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
-#endif
             }
         }
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
@@ -9704,16 +10689,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+#if defined(GGML_USE_CLBLAST)
+    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
         return;
     }
 #endif
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -9743,20 +10728,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                     assert(id*sizeof(float) <= params->wsize);
                 }
 
-#if defined(GGML_USE_CLBLAST)
-                const float * x = wdata;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                // zT = y * xT
-                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne10,
-                        0.0f,    d, ne01,
-                        GGML_TYPE_F32);
-#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -9768,7 +10739,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
-#endif
             }
         }
 
@@ -9924,16 +10894,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+#if defined(GGML_USE_CLBLAST)
+    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
         return;
     }
 #endif
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -9956,9 +10926,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-#if defined(GGML_USE_CLBLAST)
-                const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
-#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -9970,23 +10937,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 }
 
                 const float * x = wdata;
-#endif
 
-#if defined(GGML_USE_CLBLAST)
-                // zT = y * xT
-                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne10,
-                        0.0f,    d, ne01,
-                        type);
-#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
-#endif
             }
         }
 
@@ -10081,6 +11037,11 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -10099,6 +11060,176 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_out_prod
+
+
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_scale
 
 static void ggml_compute_forward_scale_f32(
@@ -10221,7 +11352,7 @@ static void ggml_compute_forward_set_f32(
     const int im2 = (ne12 == 0 ? 0 : ne12-1);
     const int im3 = (ne13 == 0 ? 0 : ne13-1);
 
-    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  < ggml_nbytes(dst));
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
 
     GGML_ASSERT(nb10 == sizeof(float));
 
@@ -10264,6 +11395,11 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10429,6 +11565,11 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -10511,7 +11652,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10655,8 +11800,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10664,7 +11809,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10688,8 +11833,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -10825,39 +11970,135 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
-// ggml_compute_forward_alibi
+// ggml_compute_forward_soft_max_back
 
-static void ggml_compute_forward_alibi_f32(
+static void ggml_compute_forward_soft_max_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int   n_past   = ((int32_t *) src1->data)[0];
-    const int   n_head   = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *)   src1->data)[2];
-
-    assert(n_past >= 0);
+    // TODO: handle transposed/permuted matrices
 
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    const int n  = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    //const int nb3 = src0->nb[3];
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int   n_past   = ((int32_t *) src1->data)[0];
+    const int   n_head   = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *)   src1->data)[2];
+
+    assert(n_past >= 0);
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n  = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
 
     assert(nb0 == sizeof(float));
     assert(ne1 + n_past == ne0); (void) n_past;
@@ -10897,8 +12138,9 @@ static void ggml_compute_forward_alibi_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10975,6 +12217,12 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -10994,15 +12242,16 @@ static void ggml_compute_forward_clamp_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int min = ((float *) src1->data)[0];
-    const int max = ((float *) src1->data)[1];
+    const float min = ((float *) src1->data)[0];
+    const float max = ((float *) src1->data)[1];
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -11046,6 +12295,12 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11135,7 +12390,7 @@ static void ggml_compute_forward_rope_f32(
                         theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
@@ -11156,7 +12411,7 @@ static void ggml_compute_forward_rope_f32(
                             const int64_t i0 = ib*n_dims + ic/2;
 
                             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                             const float x0 = src[0];
                             const float x1 = src[n_dims/2];
@@ -11554,9 +12809,9 @@ static void ggml_compute_forward_rope_back(
     }
 }
 
-// ggml_compute_forward_conv_1d_1s
+// ggml_compute_forward_conv_1d_s1_ph
 
-static void ggml_compute_forward_conv_1d_1s_f16_f32(
+static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11676,7 +12931,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_1s_f32(
+static void ggml_compute_forward_conv_1d_s1_ph_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11796,7 +13051,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_1s(
+static void ggml_compute_forward_conv_1d_s1_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11804,11 +13059,11 @@ static void ggml_compute_forward_conv_1d_1s(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -11817,9 +13072,9 @@ static void ggml_compute_forward_conv_1d_1s(
     }
 }
 
-// ggml_compute_forward_conv_1d_2s
+// ggml_compute_forward_conv_1d_s2_ph
 
-static void ggml_compute_forward_conv_1d_2s_f16_f32(
+static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11939,7 +13194,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_2s_f32(
+static void ggml_compute_forward_conv_1d_s2_ph_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12059,7 +13314,143 @@ static void ggml_compute_forward_conv_1d_2s_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_2s(
+static void ggml_compute_forward_conv_1d_s2_ph(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_conv_2d_sk_p0
+
+static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    //const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    //const int nb01 = src0->nb[1];
+    //const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    //const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
+
+    //const int nb0  = dst->nb[0];
+    //const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    //const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk0 = ne00;
+    const int nk1 = ne01;
+
+    // size of the convolution row - the kernel size unrolled across all channels
+    // round-up so it is more suitable for SIMD
+    const int ew0 = ggml_up32(nk0*nk1*ne02);
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare source data (src1)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int i12 = 0; i12 < ne12; i12++) {
+                const float * const src = (float *)((char *) src1->data + i12*nb12);
+                ggml_fp16_t * dst_data = wdata;
+
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        for (int ik1 = 0; ik1 < nk1; ik1++) {
+                            for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                    GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+    for (int i2 = ip0; i2 < ip1; i2++) {
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+
+        for (int i1 = 0; i1 < ne1; ++i1) {
+            for (int i0 = 0; i0 < ne0; ++i0) {
+                ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
+                        (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
+                        (ggml_fp16_t *)                wdata + (i1*ne0 + i0)*ew0);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_2d_sk_p0(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12067,11 +13458,12 @@ static void ggml_compute_forward_conv_1d_2s(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
+                //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
+                GGML_ASSERT(false);
             } break;
         default:
             {
@@ -12766,91 +14158,1025 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
-// ggml_compute_forward_map_unary
+// ggml_compute_forward_flash_attn_back
 
-static void ggml_compute_forward_map_unary_f32(
+static void ggml_compute_forward_flash_attn_back_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
 
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
 
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
 
-static void ggml_compute_forward_map_unary(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
 
-// ggml_compute_forward_map_binary
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
 
-static void ggml_compute_forward_map_binary_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR =  S + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                   shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur   = k[:D,:M,iq2,iq3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur   = v[:M,:D,iq2,iq3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0     = -Inf                   [D,1,1,1]
+                //  ~S1[i]  = dot(kcur[:D,i], qcur)
+                //   S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                //   S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                //   S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                //  ~S5[i]  = dot(vcur[:,i], S4)
+                //   S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                //   dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                // dst                               backward-/ grad[dst]                 = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur @ kcur.T
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1  + i2*nbgq2  + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1   + i2*nbk2   + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k  + (ic*nbgk1  + i2*nbgk2  + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1  + i2*nbgk2  + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1   + i2*nbq2   + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                        (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1  + i2*nbd2  + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
+
+    const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
+    const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
+    const int32_t w    = ((const int32_t *)(opt0->data))[2];
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
+                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+
+    const int32_t w = ((const int32_t *)(opt0->data))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
 
     assert( dst->nb[0] == sizeof(float));
     assert(src0->nb[0] == sizeof(float));
     assert(src1->nb[0] == sizeof(float));
 
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        struct ggml_tensor * dst,
+        const ggml_custom1_op_f32_t fun) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    fun(dst, a);
+}
+
+
+static void ggml_compute_forward_map_custom1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        struct ggml_tensor * dst,
+        const ggml_custom1_op_f32_t fun) {
+    switch (a->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_custom1_f32(params, a, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        struct ggml_tensor * dst,
+        const ggml_custom2_op_f32_t fun) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    fun(dst, a, b);
+}
+
+
+static void ggml_compute_forward_map_custom2(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        struct ggml_tensor * dst,
+        const ggml_custom2_op_f32_t fun) {
+    switch (a->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_custom2_f32(params, a, b, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        const struct ggml_tensor * c,
+        struct ggml_tensor * dst,
+        const ggml_custom3_op_f32_t fun) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    fun(dst, a, b, c);
+}
+
+
+static void ggml_compute_forward_map_custom3(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        const struct ggml_tensor * c,
+        struct ggml_tensor * dst,
+        const ggml_custom3_op_f32_t fun) {
+    switch (a->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_custom3_f32(params, a, b, c, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
     }
 }
 
+// ggml_compute_forward_cross_entropy_loss_back
 
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const float eps = 1e-9f;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d   = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0))                                                       grad[st0] = grad[st1]*(1.0 - eps)/sum
+                                                          // from softmax_back:            grad[s0]  = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum);           // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps));  // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps);        // st3 = st2 + eps               grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st);              // st4 = log(st3)                grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1);          // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st);      // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            //   from softmax_back:
+            //   dxk = yk * (dyk - dot(y, dy))
+            //   dot_y_dy := dot(y, dy)
+            //   dx := dy
+            //   dx := dx - dot_y_dy
+            //   dx := dx * y
+            //   postorder:
+            //   dot_st1_dst1 := dot(st1, grad[st1])
+            //   grad[s0] := grad[st1]
+            //   grad[s0] := grad[s0] - dot_st1_dst1
+            //   grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm           := softmax(s0)
+            // grad[s0]     := sm*(1.0 - eps)
+            // grad[s0]     := grad[s0] + eps
+            // grad[s0]     := s1 / grad[s0]
+            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0]     := grad[s0] - dot_st1_dst1
+            // grad[s0]     := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }
+
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_cross_entropy_loss_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
@@ -12859,11 +15185,21 @@ static void ggml_compute_forward_map_binary(
     }
 }
 
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
+#ifdef GGML_USE_CUBLAS
+    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
+    if (skip_cpu) {
+        return;
+    }
+    GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
+    GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
+#endif // GGML_USE_CUBLAS
+
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -12921,6 +15257,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -12945,6 +15285,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -12969,6 +15313,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13025,6 +15373,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13041,25 +15393,44 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case GGML_OP_CONV_1D_1S:
+        case GGML_OP_CONV_1D_S1_PH:
             {
-                ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
+                ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case GGML_OP_CONV_1D_2S:
+        case GGML_OP_CONV_1D_S2_PH:
             {
-                ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
+                ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CONV_2D_SK_P0:
+            {
+                ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
                 GGML_ASSERT(t == 0 || t == 1);
-                bool masked = t != 0;
+                const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
             } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
+        case GGML_OP_WIN_PART:
+            {
+                ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
+        case GGML_OP_WIN_UNPART:
+            {
+                ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13072,6 +15443,34 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13210,11 +15609,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                                ggml_scale(ctx,
                                     ggml_div(ctx,
-                                        ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
-                                        tensor)),
+                                        tensor->grad,
+                                        tensor),
+                                    ggml_new_f32(ctx, 0.5f)),
                                 inplace);
                 }
             } break;
@@ -13260,43 +15659,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
-                    const int nc  = tensor->ne[0];
-                    const int nr  = tensor->ne[1];
-                    const int nc0 = src0->ne[0];
-                    const int nr0 = src0->ne[1];
-                    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    // tensor->grad [nc,nr,1,1]
-                    // reshape      [nc0,nc/nc0,nr0,nr/nr0]
-                    // permute      [nc0,nr0,nc/nc0,nr/nr0]
-                    // substitute   [nc0,nr0,ncr,nrr]
-                    // reshape      [nc0*nr0,ncr*nrr,1,1]
-                    // transpose    [ncr*nrr,nc0*nr0,1,1]
-                    // sum rows     [1,nc0*nr0,1,1]
-                    // transpose    [nc0*nr0,1,1]
-                    // reshape      [nc0,nr0,1,1] reshape_1d or reshape_2d
-                    // add to src0->grad
-
-                    int64_t ne[4]  = {nc0,ncr,nr0,nrr};
-
-                    struct ggml_tensor* F00 = tensor->grad;
-                    struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                    struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                    struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                    struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                    struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                    struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                    struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                    struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                    struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                    struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                                src0->grad,
-                                F10,
-                                inplace);
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                if (src0->grad) {
+                    // TODO: test this
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat(ctx, tensor->grad, src0->grad),
+                            inplace);
                 }
             } break;
         case GGML_OP_ABS:
@@ -13344,6 +15720,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ALIBI:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -13403,38 +15783,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
                 // necessary for llama
                 if (src0->grad) {
-                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                // ds0 = dt.dot(s1.T)
-                                // ggml_out_prod(ctx, // [n,m]
-                                //     src1,          // [n,p]
-                                //     tensor->grad), // [m,p]
-                                // for now just using A*B==(B.T*A.T).T
-                                ggml_cont(ctx,                      // [n,m]
-                                    ggml_transpose(ctx,             // [n,m]
-                                        ggml_mul_mat(ctx,           // [m,n]
-                                            ggml_cont(ctx,          // [p,m]
-                                                ggml_transpose(ctx, // [p,m]
-                                                    tensor->grad)), // [m,p]
-                                            ggml_cont(ctx,          // [p,n]
-                                                ggml_transpose(ctx, // [p,n]
-                                                    src1))))),      // [n,p]
+                                ggml_out_prod(ctx, // [n,m]
+                                    src1,          // [n,p]
+                                    tensor->grad), // [m,p]
                                 inplace);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                // ds1 = s0.T.dot(dt):
-                                ggml_mul_mat(ctx,                   // [n,p]
-                                    ggml_cont(ctx,                  // [m,n]
-                                        ggml_transpose(ctx, src0)), // [m,n]
-                                    tensor->grad),                  // [m,p]
+                                // ggml_mul_mat(ctx,                   // [n,p]
+                                //     ggml_cont(ctx,                  // [m,n]
+                                //         ggml_transpose(ctx, src0)), // [m,n]
+                                //     tensor->grad),                  // [m,p]
+
+                                // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                                // // avoid transpose of src0, rather transpose smaller tensor->grad
+                                // // and then use ggml_out_prod
+                                ggml_out_prod(ctx,                  // [n,p]
+                                    src0,                           // [n,m]
+                                    ggml_transpose(ctx,             // [p,m]
+                                        tensor->grad)),             // [m,p]
                                 inplace);
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SCALE:
             {
                 // necessary for llama
@@ -13536,7 +15915,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     size_t offset;
-                    memcpy(&offset, tensor->padding, sizeof(offset));
+
+                    GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+                    memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
 
                     size_t nb1     = tensor->nb[1];
                     size_t nb2     = tensor->nb[2];
@@ -13563,10 +15944,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    int axis0 = tensor->padding[0] & 0x3;
-                    int axis1 = tensor->padding[1] & 0x3;
-                    int axis2 = tensor->padding[2] & 0x3;
-                    int axis3 = tensor->padding[3] & 0x3;
+                    int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                    int axis0 = axes[0] & 0x3;
+                    int axis1 = axes[1] & 0x3;
+                    int axis2 = axes[2] & 0x3;
+                    int axis3 = axes[3] & 0x3;
                     int axes_backward[4] = {0,0,0,0};
                     axes_backward[axis0] = 0;
                     axes_backward[axis1] = 1;
@@ -13650,50 +16032,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,                   // [ne0,ne1,ne2,ne3]
-                            ggml_reshape(ctx,             // [ne0,ne1,ne2,ne3]
-                                ggml_mul_mat(ctx,         // [ne0,1,ne1*ne2,ne3]
-                                    ggml_sub(ctx,         // [ne0,ne0,ne1*ne2,ne3]
-                                        ggml_diag(ctx,    // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2),     // [ne0,1,ne1*ne2,ne3]
-                                        ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2_t,    // [1,ne0,ne1*ne2,ne3]
-                                            tensor2_t)),  // [1,ne0,ne1*ne2,ne3]
-                                    grad2),               // [ne0,1,ne1*ne2,ne3]
-                                src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                        inplace);
                 }
+
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_ROPE:
             {
@@ -13738,24 +16086,206 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     // noop
                 }
             } break;
-        case GGML_OP_CONV_1D_1S:
+        case GGML_OP_CONV_1D_S1_PH:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D_2S:
+        case GGML_OP_CONV_1D_S2_PH:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_2D_SK_P0:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                GGML_ASSERT(false); // not supported
+                struct ggml_tensor * flash_grad = NULL;
+                if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                    int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                    GGML_ASSERT(t == 0 || t == 1);
+                    bool masked = t != 0;
+                    flash_grad =
+                        ggml_flash_attn_back(ctx,
+                            src0,
+                            src1,
+                            tensor->opt[0],
+                            tensor->grad,
+                            masked);
+                }
+
+                if (src0->grad) {
+                    struct ggml_tensor * grad_q = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = 0;
+                    switch(src0->n_dims) {
+                        case 2:
+                            {
+                                grad_q = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    nb0*src0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_q = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    src0->ne[3],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            grad_q,
+                            inplace);
+                }
+
+                if (src1->grad) {
+                    struct ggml_tensor * grad_k = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                    switch(src1->n_dims) {
+                        case 2:
+                            {
+                                grad_k = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    nb0*src1->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_k = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    src1->ne[3],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src1->grad = ggml_add_impl(ctx,
+                            src1->grad,
+                            grad_k,
+                            inplace);
+                }
+
+                struct ggml_tensor * opt0 = tensor->opt[0];
+
+                if (opt0->grad) {
+                    struct ggml_tensor * grad_v = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                    switch(opt0->n_dims) {
+                        case 2:
+                            {
+                                grad_v = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    nb0*opt0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_v = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    opt0->ne[3],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    opt0->grad = ggml_add_impl(ctx,
+                            opt0->grad,
+                            grad_v,
+                            inplace);
+                }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
+        case GGML_OP_MAP_CUSTOM1:
+        case GGML_OP_MAP_CUSTOM2:
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_cross_entropy_loss_back(ctx,
+                                    src0,
+                                    src1,
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 GGML_ASSERT(false); // not supported
             } break;
@@ -13810,11 +16340,19 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
         GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);
 
+        if (strlen(node->name) == 0) {
+            ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
+        }
+
         cgraph->leafs[cgraph->n_leafs] = node;
         cgraph->n_leafs++;
     } else {
         GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);
 
+        if (strlen(node->name) == 0) {
+            ggml_format_name(node, "node_%d", cgraph->n_nodes);
+        }
+
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->grads[cgraph->n_nodes] = node->grad;
         cgraph->n_nodes++;
@@ -14127,6 +16665,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
                 case GGML_OP_REPEAT:
+                case GGML_OP_REPEAT_BACK:
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG:
@@ -14137,6 +16676,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MUL:
                 case GGML_OP_GELU:
+                case GGML_OP_GELU_QUICK:
                 case GGML_OP_SILU:
                 case GGML_OP_SILU_BACK:
                 case GGML_OP_NORM:
@@ -14146,6 +16686,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         node->n_tasks = n_threads;
                     } break;
                 case GGML_OP_MUL_MAT:
+                case GGML_OP_OUT_PROD:
                     {
                         node->n_tasks = n_threads;
 
@@ -14162,12 +16703,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                             node->n_tasks = 1; // TODO: this actually is doing nothing
                                                 //       the threads are still spinning
-                            cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
+                        }
+                        else
+#elif defined(GGML_USE_CLBLAST)
+                        if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
+                            node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                //       the threads are still spinning
+                            cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
                         }
                         else
 #endif
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -14181,13 +16728,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                             }
 #endif
                         } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -14222,6 +16769,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_SOFT_MAX:
+                case GGML_OP_SOFT_MAX_BACK:
                 case GGML_OP_ROPE:
                 case GGML_OP_ROPE_BACK:
                     {
@@ -14235,8 +16783,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1; //TODO
                     } break;
-                case GGML_OP_CONV_1D_1S:
-                case GGML_OP_CONV_1D_2S:
+                case GGML_OP_CONV_1D_S1_PH:
+                case GGML_OP_CONV_1D_S2_PH:
                     {
                         node->n_tasks = n_threads;
 
@@ -14263,6 +16811,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             GGML_ASSERT(false);
                         }
 
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_CONV_2D_SK_P0:
+                    {
+                        node->n_tasks = n_threads;
+
+                        GGML_ASSERT(node->src1->ne[3] == 1);
+
+                        const int64_t ne00 = node->src0->ne[0]; // W
+                        const int64_t ne01 = node->src0->ne[1]; // H
+                        const int64_t ne02 = node->src0->ne[2]; // C
+                        const int64_t ne03 = node->src0->ne[3]; // N
+
+                        const int64_t ne10 = node->src1->ne[0]; // W
+                        const int64_t ne11 = node->src1->ne[1]; // H
+                        const int64_t ne12 = node->src1->ne[2]; // C
+
+                        const int64_t nk = ne00*ne01;
+
+                        UNUSED(ne02);
+                        UNUSED(ne03);
+                        UNUSED(nk);
+
+                        size_t cur = 0;
+
+                        if (node->src0->type == GGML_TYPE_F16 &&
+                            node->src1->type == GGML_TYPE_F32) {
+                            cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                        } else if (node->src0->type == GGML_TYPE_F32 &&
+                                   node->src1->type == GGML_TYPE_F32) {
+                            cur = sizeof(float)*      (ne10*ne11*ne12);
+                        } else {
+                            GGML_ASSERT(false);
+                        }
+
                         work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_FLASH_ATTN:
@@ -14303,10 +16886,52 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         work_size = MAX(work_size, cur);
                     } break;
-                case GGML_OP_MAP_UNARY:
-                case GGML_OP_MAP_BINARY:
+                case GGML_OP_FLASH_ATTN_BACK:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        const int64_t    D = node->src0->ne[0];
+                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_WIN_PART:
+                case GGML_OP_WIN_UNPART:
+                case GGML_OP_MAP_UNARY:
+                case GGML_OP_MAP_BINARY:
+                case GGML_OP_MAP_CUSTOM1:
+                case GGML_OP_MAP_CUSTOM2:
+                case GGML_OP_MAP_CUSTOM3:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_CROSS_ENTROPY_LOSS:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
                     {
-                        node->n_tasks = 1;
+                        node->n_tasks = n_threads;
+
+                        size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_NONE:
                     {
@@ -14381,144 +17006,663 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].node = node;
             }
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) > 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_store(&state_shared.has_work, true);
+        }
+
+        params.type = GGML_TASK_COMPUTE;
+        ggml_compute_forward(&params, node);
+
+        // wait for thread pool
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) != 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+        }
+
+        // FINALIZE
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            // launch thread pool
+            for (int j = 0; j < n_threads - 1; j++) {
+                workers[j].params = (struct ggml_compute_params) {
+                    .type  = GGML_TASK_FINALIZE,
+                    .ith   = j + 1,
+                    .nth   = node->n_tasks,
+                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                    .wdata = cgraph->work ? cgraph->work->data : NULL,
+                };
+                workers[j].node = node;
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) > 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_store(&state_shared.has_work, true);
+        }
+
+        params.type = GGML_TASK_FINALIZE;
+        ggml_compute_forward(&params, node);
+
+        // wait for thread pool
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) != 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+        }
+
+        // performance stats (node)
+        {
+            int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_node_start_cycles;
+            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+
+            node->perf_runs++;
+            node->perf_cycles  += perf_cycles_cur;
+            node->perf_time_us += perf_time_us_cur;
+        }
+    }
+
+    // join thread pool
+    if (n_threads > 1) {
+        atomic_store(&state_shared.stop, true);
+        atomic_store(&state_shared.has_work, true);
+
+        for (int j = 0; j < n_threads - 1; j++) {
+            int rc = ggml_thread_join(workers[j].thrd, NULL);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
+
+        ggml_lock_destroy(&state_shared.spin);
+    }
+
+    // performance stats (graph)
+    {
+        int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_start_cycles;
+        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+
+        cgraph->perf_runs++;
+        cgraph->perf_cycles  += perf_cycles_cur;
+        cgraph->perf_time_us += perf_time_us_cur;
+
+        GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
+                __func__, cgraph->perf_runs,
+                (double) perf_cycles_cur      / (double) ggml_cycles_per_ms(),
+                (double) cgraph->perf_cycles  / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
+                (double) perf_time_us_cur     / 1000.0,
+                (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+    }
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * grad = cgraph->grads[i];
+
+        if (grad) {
+            ggml_set_zero(grad);
+        }
+    }
+}
+
+struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        struct ggml_tensor * leaf = cgraph->leafs[i];
+
+        if (strcmp(leaf->name, name) == 0) {
+            return leaf;
+        }
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        if (strcmp(node->name, name) == 0) {
+            return node;
+        }
+    }
+
+    return NULL;
+}
+
+static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout) {
+    const int64_t * ne = tensor->ne;
+    const size_t  * nb = tensor->nb;
+
+    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
+            ggml_type_name(tensor->type),
+            ggml_op_name  (tensor->op),
+            tensor->n_dims,
+            ne[0], ne[1], ne[2], ne[3],
+            nb[0], nb[1], nb[2], nb[3],
+            tensor->data,
+            tensor->name);
+}
+
+static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char * arg, FILE * fout) {
+    const int64_t * ne = tensor->ne;
+    const size_t  * nb = tensor->nb;
+
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+            arg,
+            ggml_type_name(tensor->type),
+            ggml_op_name  (tensor->op),
+            tensor->n_dims,
+            ne[0], ne[1], ne[2], ne[3],
+            nb[0], nb[1], nb[2], nb[3],
+            tensor->n_tasks,
+            tensor->data,
+            tensor->name);
+}
+
+void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
+    //assert(cgraph->work      == NULL);
+    //assert(cgraph->work_size == 0);
+
+    uint64_t size_eval = 0;
+
+    // compute size of intermediate results
+    // TODO: does not take into account scratch buffers !!!!
+    for (int i = 0; i < cgraph->n_nodes; ++i) {
+        size_eval += ggml_nbytes(cgraph->nodes[i]);
+    }
+
+    // print
+    {
+        FILE * fout = stdout;
+
+        fprintf(fout, "\n");
+        fprintf(fout, "%-16s %8x\n", "magic",        GGML_FILE_MAGIC);
+        fprintf(fout, "%-16s %8d\n", "version",      GGML_FILE_VERSION);
+        fprintf(fout, "%-16s %8d\n", "leafs",        cgraph->n_leafs);
+        fprintf(fout, "%-16s %8d\n", "nodes",        cgraph->n_nodes);
+        fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
+
+        // header
+        fprintf(fout, "\n");
+        fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n",
+                "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME");
+
+        for (int i = 0; i < cgraph->n_leafs; ++i) {
+            ggml_graph_export_leaf(cgraph->leafs[i], fout);
+
+            GGML_ASSERT(cgraph->leafs[i]->op   == GGML_OP_NONE);
+            GGML_ASSERT(cgraph->leafs[i]->src0 == NULL);
+            GGML_ASSERT(cgraph->leafs[i]->src1 == NULL);
+        }
+
+        // header
+        fprintf(fout, "\n");
+        fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n",
+                "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME");
+
+        for (int i = 0; i < cgraph->n_nodes; ++i) {
+            ggml_graph_export_node(cgraph->nodes[i], "DST", fout);
+
+            if (cgraph->nodes[i]->src0) {
+                ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout);
+            }
+
+            if (cgraph->nodes[i]->src1) {
+                ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout);
+            }
+
+            for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                if (cgraph->nodes[i]->opt[j]) {
+                    ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout);
+                }
+            }
+
+            fprintf(fout, "\n");
+        }
+
+        fprintf(fout, "\n");
+    }
+
+    // write binary data
+    {
+        FILE * fout = fopen(fname, "wb");
+
+        if (!fout) {
+            fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
+            return;
+        }
+
+        // header
+        {
+            const uint32_t magic   = GGML_FILE_MAGIC;
+            const uint32_t version = GGML_FILE_VERSION;
+            const uint32_t n_leafs = cgraph->n_leafs;
+            const uint32_t nodes   = cgraph->n_nodes;
+
+            fwrite(&magic,     sizeof(uint32_t), 1, fout);
+            fwrite(&version,   sizeof(uint32_t), 1, fout);
+            fwrite(&n_leafs,   sizeof(uint32_t), 1, fout);
+            fwrite(&nodes,     sizeof(uint32_t), 1, fout);
+            fwrite(&size_eval, sizeof(uint64_t), 1, fout);
+        }
+
+        // leafs
+        {
+            for (int i = 0; i < cgraph->n_leafs; ++i) {
+                const struct ggml_tensor * tensor = cgraph->leafs[i];
+
+                const uint32_t type   = tensor->type;
+                const uint32_t op     = tensor->op;
+                const uint32_t n_dims = tensor->n_dims;
+
+                fwrite(&type,   sizeof(uint32_t), 1, fout);
+                fwrite(&op,     sizeof(uint32_t), 1, fout);
+                fwrite(&n_dims, sizeof(uint32_t), 1, fout);
+
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    const uint64_t ne = tensor->ne[j];
+                    const uint64_t nb = tensor->nb[j];
+
+                    fwrite(&ne, sizeof(uint64_t), 1, fout);
+                    fwrite(&nb, sizeof(uint64_t), 1, fout);
+                }
+
+                // store the pointer address
+                {
+                    const uint64_t ptr = (uint64_t) tensor->data;
+
+                    fwrite(&ptr, sizeof(uint64_t), 1, fout);
+                }
+
+                fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+
+                // dump the data
+                // TODO: pad this to 32 byte boundary
+                {
+                    const size_t size = ggml_nbytes(tensor);
+
+                    fwrite(tensor->data, sizeof(char), size, fout);
+                }
+            }
+        }
+
+        // nodes
+        {
+            for (int i = 0; i < cgraph->n_nodes; ++i) {
+                const struct ggml_tensor * tensor = cgraph->nodes[i];
+
+                const uint32_t type   = tensor->type;
+                const uint32_t op     = tensor->op;
+                const uint32_t n_dims = tensor->n_dims;
+
+                fwrite(&type,   sizeof(uint32_t), 1, fout);
+                fwrite(&op,     sizeof(uint32_t), 1, fout);
+                fwrite(&n_dims, sizeof(uint32_t), 1, fout);
+
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    const uint64_t ne = tensor->ne[j];
+                    const uint64_t nb = tensor->nb[j];
+
+                    fwrite(&ne, sizeof(uint64_t), 1, fout);
+                    fwrite(&nb, sizeof(uint64_t), 1, fout);
+                }
+
+                // store the pointer address
+                {
+                    const uint64_t ptr = (uint64_t) tensor->data;
+
+                    fwrite(&ptr, sizeof(uint64_t), 1, fout);
+                }
+
+                fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+
+                // output the op arguments
+                {
+                    struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
+
+                    args[0] = tensor->src0;
+                    args[1] = tensor->src1;
+
+                    for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                        args[2 + j] = tensor->opt[j];
+                    }
+
+                    for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+                        if (args[j]) {
+                            int32_t idx = -1;
+
+                            // check if leaf
+                            {
+                                for (int k = 0; k < cgraph->n_leafs; ++k) {
+                                    if (args[j] == cgraph->leafs[k]) {
+                                        idx = k;
+                                        break;
+                                    }
+                                }
+                            }
+
+                            // check if node
+                            if (idx == -1) {
+                                for (int k = 0; k < cgraph->n_nodes; ++k) {
+                                    if (args[j] == cgraph->nodes[k]) {
+                                        idx = GGML_MAX_NODES + k;
+                                        break;
+                                    }
+                                }
+                            }
+
+                            if (idx == -1) {
+                                fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i);
+                                return;
+                            }
+
+                            fwrite(&idx, sizeof(int32_t), 1, fout);
+                        } else {
+                            const int32_t nul = -1;
+
+                            fwrite(&nul, sizeof(int32_t), 1, fout);
+                        }
+                    }
+                }
+            }
+        }
+
+        fclose(fout);
+    }
+}
+
+struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
+    assert(*ctx_data == NULL);
+    assert(*ctx_eval == NULL);
+
+    struct ggml_cgraph result = { 0 };
 
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
+    struct ggml_tensor * data = NULL;
 
-            atomic_store(&state_shared.has_work, true);
+    // read file into data
+    {
+        FILE * fin = fopen(fname, "rb");
+        if (!fin) {
+            fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
+            return result;
         }
 
-        params.type = GGML_TASK_COMPUTE;
-        ggml_compute_forward(&params, node);
+        size_t fsize = 0;
 
-        // wait for thread pool
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
+        fseek(fin, 0, SEEK_END);
+        fsize = ftell(fin);
+        fseek(fin, 0, SEEK_SET);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
+        // create the data context
+        {
+            const size_t overhead = 1*ggml_tensor_overhead();
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            struct ggml_init_params params = {
+                .mem_size   = fsize + overhead,
+                .mem_buffer = NULL,
+                .no_alloc   = false,
+            };
 
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            *ctx_data = ggml_init(params);
+
+            if (!*ctx_data) {
+                fprintf(stderr, "%s: failed to create ggml context\n", __func__);
+                fclose(fin);
+                return result;
             }
         }
 
-        // FINALIZE
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
+        data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+        {
+            const size_t ret = fread(data->data, sizeof(char), fsize, fin);
+            if (ret != fsize) {
+                fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
+                fclose(fin);
+                return result;
             }
+        }
 
-            // launch thread pool
-            for (int j = 0; j < n_threads - 1; j++) {
-                workers[j].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_FINALIZE,
-                    .ith   = j + 1,
-                    .nth   = node->n_tasks,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                };
-                workers[j].node = node;
-            }
+        fclose(fin);
+    }
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+    // populate result
+    {
+        char * ptr = (char *) data->data;
 
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
+        const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic);
 
-            atomic_store(&state_shared.has_work, true);
+        if (magic != GGML_FILE_MAGIC) {
+            fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic);
+            return result;
         }
 
-        params.type = GGML_TASK_FINALIZE;
-        ggml_compute_forward(&params, node);
+        const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version);
 
-        // wait for thread pool
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
+        if (version != GGML_FILE_VERSION) {
+            fprintf(stderr, "%s: invalid version number\n", __func__);
+            return result;
+        }
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
+        const uint32_t n_leafs   = *(const uint32_t *) ptr; ptr += sizeof(n_leafs);
+        const uint32_t n_nodes   = *(const uint32_t *) ptr; ptr += sizeof(n_nodes);
+        const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval);
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+        result.n_leafs = n_leafs;
+        result.n_nodes = n_nodes;
 
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+        // create the data context
+        {
+            const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead();
+
+            struct ggml_init_params params = {
+                .mem_size   = size_eval + overhead,
+                .mem_buffer = NULL,
+                .no_alloc   = true,
+            };
+
+            *ctx_eval = ggml_init(params);
+
+            if (!*ctx_eval) {
+                fprintf(stderr, "%s: failed to create ggml context\n", __func__);
+                return result;
             }
         }
 
-        // performance stats (node)
+        // leafs
         {
-            int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_node_start_cycles;
-            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+            uint32_t type;
+            uint32_t op;
+            uint32_t n_dims;
 
-            node->perf_runs++;
-            node->perf_cycles  += perf_cycles_cur;
-            node->perf_time_us += perf_time_us_cur;
-        }
-    }
+            for (uint32_t i = 0; i < n_leafs; ++i) {
+                type   = *(const uint32_t *) ptr; ptr += sizeof(type);
+                op     = *(const uint32_t *) ptr; ptr += sizeof(op);
+                n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
-    // join thread pool
-    if (n_threads > 1) {
-        atomic_store(&state_shared.stop, true);
-        atomic_store(&state_shared.has_work, true);
+                int64_t ne[GGML_MAX_DIMS];
+                size_t  nb[GGML_MAX_DIMS];
 
-        for (int j = 0; j < n_threads - 1; j++) {
-            int rc = ggml_thread_join(workers[j].thrd, NULL);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    uint64_t ne_cur;
+                    uint64_t nb_cur;
+
+                    ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur);
+                    nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur);
+
+                    ne[j] = ne_cur;
+                    nb[j] = nb_cur;
+                }
+
+                struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
+
+                tensor->op = (enum ggml_op) op;
+
+                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
+
+                memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
+
+                tensor->data = (void *) ptr;
+
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    tensor->nb[j] = nb[j];
+                }
+
+                result.leafs[i] = tensor;
+
+                ptr += ggml_nbytes(tensor);
+
+                fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor));
+            }
         }
 
-        ggml_lock_destroy(&state_shared.spin);
-    }
+        ggml_set_no_alloc(*ctx_eval, false);
 
-    // performance stats (graph)
-    {
-        int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_start_cycles;
-        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+        // nodes
+        {
+            uint32_t type;
+            uint32_t op;
+            uint32_t n_dims;
 
-        cgraph->perf_runs++;
-        cgraph->perf_cycles  += perf_cycles_cur;
-        cgraph->perf_time_us += perf_time_us_cur;
+            for (uint32_t i = 0; i < n_nodes; ++i) {
+                type   = *(const uint32_t *) ptr; ptr += sizeof(type);
+                op     = *(const uint32_t *) ptr; ptr += sizeof(op);
+                n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
-        GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
-                __func__, cgraph->perf_runs,
-                (double) perf_cycles_cur      / (double) ggml_cycles_per_ms(),
-                (double) cgraph->perf_cycles  / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
-                (double) perf_time_us_cur     / 1000.0,
-                (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
-    }
-}
+                enum ggml_op eop = (enum ggml_op) op;
 
-void ggml_graph_reset(struct ggml_cgraph * cgraph) {
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * grad = cgraph->grads[i];
+                int64_t ne[GGML_MAX_DIMS];
+                size_t  nb[GGML_MAX_DIMS];
 
-        if (grad) {
-            ggml_set_zero(grad);
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    uint64_t ne_cur;
+                    uint64_t nb_cur;
+
+                    ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur);
+                    nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur);
+
+                    ne[j] = ne_cur;
+                    nb[j] = nb_cur;
+                }
+
+                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
+
+                const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
+
+                const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
+
+                struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
+
+                // parse args
+                for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+                    const int32_t arg_idx = ptr_arg_idx[j];
+
+                    if (arg_idx == -1) {
+                        continue;
+                    }
+
+                    if (arg_idx < GGML_MAX_NODES) {
+                        args[j] = result.leafs[arg_idx];
+                    } else {
+                        args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+                    }
+                }
+
+                // create the tensor
+                // "view" operations are handled differently
+                // TODO: handle inplace ops - currently a copy is always made
+
+                struct ggml_tensor * tensor = NULL;
+
+                switch (eop) {
+                    // TODO: implement other view ops
+                    case GGML_OP_RESHAPE:
+                        {
+                            tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
+                        } break;
+                    case GGML_OP_VIEW:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+
+                            uint64_t offs;
+                            memcpy(&offs, args[2]->data, sizeof(offs));
+
+                            tensor->data = ((char *) tensor->data) + offs;
+                        } break;
+                    case GGML_OP_TRANSPOSE:
+                        {
+                            tensor = ggml_transpose(*ctx_eval, args[0]);
+                        } break;
+                    case GGML_OP_PERMUTE:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                        } break;
+                    default:
+                        {
+                            tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
+
+                            tensor->op = eop;
+                        } break;
+                }
+
+                memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    tensor->nb[j] = nb[j];
+                }
+
+                tensor->src0 = args[0];
+                tensor->src1 = args[1];
+
+                for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                    tensor->opt[j] = args[2 + j];
+                }
+
+                result.nodes[i] = tensor;
+
+                fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor));
+            }
         }
     }
+
+    return result;
 }
 
 void ggml_graph_print(const struct ggml_cgraph * cgraph) {
@@ -14538,7 +17682,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
                 (double) node->perf_time_us / 1000.0,
@@ -14552,7 +17696,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
-                GGML_OP_LABEL[node->op]);
+                GGML_OP_NAME[node->op]);
     }
 
     for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -14560,7 +17704,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
             continue;
         }
 
-        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
+        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0);
     }
 
     GGML_PRINT("========================================\n");
@@ -14593,6 +17737,26 @@ static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgr
     return NULL;
 }
 
+static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {
+    struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
+    struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);
+    fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
+            gparent0 ? (void *) gparent0 : (void *) parent,
+            gparent0 ? "g" : "x",
+            gparent ? (void *) gparent : (void *) node,
+            gparent ? "g" : "x",
+            gparent ? "empty" : "vee",
+            gparent ? "dashed" : "solid",
+            label);
+}
+
+static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {
+    fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
+            (void *) parent, "x",
+            (void *) node, "x",
+            label);
+}
+
 void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
     char color[16];
 
@@ -14628,12 +17792,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
                 (void *) node, color);
 
         if (strlen(node->name) > 0) {
-            fprintf(fp, "%s |", node->name);
+            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
+        } else {
+            fprintf(fp, "(%s)|", ggml_type_name(node->type));
         }
 
-        fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
-                i, node->ne[0], node->ne[1],
-                GGML_OP_SYMBOL[node->op]);
+        if (node->n_dims == 2) {
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], GGML_OP_SYMBOL[node->op]);
+        } else {
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], GGML_OP_SYMBOL[node->op]);
+        }
 
         if (node->grad) {
             fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
@@ -14653,18 +17821,29 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
                 (void *) node, color);
 
         if (strlen(node->name) > 0) {
-                fprintf(fp, "%s | ", node->name);
+            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
+        } else {
+            fprintf(fp, "(%s)|", ggml_type_name(node->type));
         }
-        if (ggml_nelements(node) == 1) {
-            if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
-                fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
-            }
-            else {
-                fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
+
+        fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
+        if (ggml_nelements(node) < 5) {
+            fprintf(fp, " | (");
+            for (int j = 0; j < ggml_nelements(node); j++) {
+                if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
+                    fprintf(fp, "%d", ggml_get_i32_1d(node, j));
+                }
+                else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
+                    fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
+                }
+                else {
+                    fprintf(fp, "#");
+                }
+                if (j < ggml_nelements(node) - 1) {
+                    fprintf(fp, ", ");
+                }
             }
-        }
-        else {
-            fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
+            fprintf(fp, ")");
         }
         fprintf(fp, "\"; ]\n");
     }
@@ -14672,30 +17851,20 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
     for (int i = 0; i < gb->n_nodes; i++) {
         struct ggml_tensor * node = gb->nodes[i];
 
-        struct ggml_tensor * parent = ggml_graph_get_parent(gb, node);
-
         if (node->src0) {
-            struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0);
-
-            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n",
-                    parent0 ? (void *) parent0 : (void *) node->src0,
-                    parent0 ? "g" : "x",
-                    parent ? (void *) parent : (void *) node,
-                    parent ? "g" : "x",
-                    parent ? "empty" : "vee",
-                    parent ? "dashed" : "solid");
+            ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x");
         }
 
         if (node->src1) {
-            struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1);
+            ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y");
+        }
 
-            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n",
-                    parent1 ? (void *) parent1 : (void *) node->src1,
-                    parent1 ? "g" : "x",
-                    parent ? (void *) parent : (void *) node,
-                    parent ? "g" : "x",
-                    parent ? "empty" : "vee",
-                    parent ? "dashed" : "solid");
+        for (int j = 0; j < GGML_MAX_OPT; j++) {
+            if (node->opt[j]) {
+                char label[16];
+                snprintf(label, sizeof(label), "opt %d", j);
+                ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label);
+            }
         }
     }
 
@@ -14703,15 +17872,19 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         struct ggml_tensor * node = gb->leafs[i];
 
         if (node->src0) {
-            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n",
-                    (void *) node->src0, "x",
-                    (void *) node, "x");
+            ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x");
         }
 
         if (node->src1) {
-            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n",
-                    (void *) node->src1, "x",
-                    (void *) node, "x");
+            ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y");
+        }
+
+        for (int j = 0; j < GGML_MAX_OPT; j++) {
+            if (node->opt[j]) {
+                char label[16];
+                snprintf(label, sizeof(label), "opt %d", j);
+                ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label);
+            }
         }
     }
 
@@ -14765,6 +17938,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 
 static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -14790,25 +17964,29 @@ static enum ggml_opt_result ggml_opt_adam(
         }
     }
 
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+        int iter = opt->iter;
+        ggml_opt_init(opt->ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
     // constants
-    const float alpha = params.adam.alpha;
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
 
-    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
-    float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
-    float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
-    float * m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
-    float * v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
-    float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
-    float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
-
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * x  = opt->adam.x->data;  // view of the parameters
+    float * g1 = opt->adam.g1->data; // gradient
+    float * g2 = opt->adam.g2->data; // gradient squared
+    float * m  = opt->adam.m->data;  // first moment
+    float * v  = opt->adam.v->data;  // second moment
+    float * mh = opt->adam.mh->data; // first moment hat
+    float * vh = opt->adam.vh->data; // second moment hat
 
-    // initialize
-    ggml_vec_set_f32(nx, m, 0.0f);
-    ggml_vec_set_f32(nx, v, 0.0f);
+    float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
     // update view
     ggml_opt_get_params(np, ps, x);
@@ -14818,16 +17996,27 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_set_f32      (f->grad, 1.0f);
     ggml_graph_compute(ctx, gb);
 
-    float fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
-        pf[0] = fx_prev;
+        pf[opt->iter % params.past] = opt->adam.fx_prev;
+    }
+
+    // initialize
+    if (opt->just_initialized) {
+        opt->adam.n_no_improvement = 0;
+        opt->just_initialized = false;
     }
 
-    int n_no_improvement = 0;
-    float fx_best = fx_prev;
+    float * fx_best = &opt->adam.fx_best;
+    float * fx_prev = &opt->adam.fx_prev;
+    int * n_no_improvement = &opt->adam.n_no_improvement;
+
+    int iter0 = opt->iter;
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);
 
         GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -14861,17 +18050,22 @@ static enum ggml_opt_result ggml_opt_adam(
 
             // m^hat = m_t / (1 - beta1^t)
             // v^hat = v_t / (1 - beta2^t)
-            // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+            // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+            // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+            // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+            // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+            // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
             ggml_vec_cpy_f32  (nx, mh, m);
             ggml_vec_cpy_f32  (nx, vh, v);
 
-            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
-            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, t + 1)));
+            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, opt->iter)));
 
             ggml_vec_sqrt_f32 (nx, vh, vh);
             ggml_vec_acc1_f32 (nx, vh, eps);
 
             ggml_vec_div_f32  (nx, mh, mh, vh);
+            ggml_vec_scale_f32(nx, x,  1.0f - decay);
             ggml_vec_sub_f32  (nx, x,  x,  mh);
 
             // update the parameters
@@ -14885,7 +18079,7 @@ static enum ggml_opt_result ggml_opt_adam(
         const float fx = ggml_get_f32_1d(f, 0);
 
         // check convergence
-        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+        if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");
 
             return GGML_OPT_OK;
@@ -14894,32 +18088,32 @@ static enum ggml_opt_result ggml_opt_adam(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= t) {
-                const float rate = (pf[t%params.past] - fx)/fx;
+            if (params.past <= iter0 + t) {
+                const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[t%params.past] = fx;
+            pf[(iter0 + t)%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx_best > fx) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx_best[0] > fx) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                ++n_no_improvement;
+                ++n_no_improvement[0];
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        fx_prev = fx;
+        fx_prev[0] = fx;
 
         {
             const int64_t t_end_cpu = ggml_cycles();
@@ -15058,6 +18252,7 @@ static enum ggml_opt_result linesearch_backtracking(
 
 static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15090,31 +18285,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
-    float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
-    float * g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
-    float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
-    float * d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+        int iter = opt->iter;
+        ggml_opt_init(ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
+    float * x  = opt->lbfgs.x->data;  // current parameters
+    float * xp = opt->lbfgs.xp->data; // previous parameters
+    float * g  = opt->lbfgs.g->data;  // current gradient
+    float * gp = opt->lbfgs.gp->data; // previous gradient
+    float * d  = opt->lbfgs.d->data;  // search direction
 
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
 
     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
-    float step  = 0.0f;
 
     // initialize x from the graph nodes
     ggml_opt_get_params(np, ps, x);
 
     // the L-BFGS memory
-    struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
-
-    for (int i = 0; i < m; ++i) {
-        lm[i].alpha = 0.0f;
-        lm[i].ys    = 0.0f;
-        lm[i].s     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-        lm[i].y     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-    }
+    float * lm_alpha = opt->lbfgs.lmal->data;
+    float * lm_ys    = opt->lbfgs.lmys->data;
+    float * lm_s     = opt->lbfgs.lms->data;
+    float * lm_y     = opt->lbfgs.lmy->data;
 
     // evaluate the function value and its gradient
     {
@@ -15129,12 +18325,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         fx = ggml_get_f32_1d(f, 0);
     }
 
-    if (pf) {
-        pf[0] = fx;
-    }
-
-    float fx_best = fx;
-
     // search direction = -gradient
     ggml_vec_neg_f32(nx, d, g);
 
@@ -15151,26 +18341,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         return GGML_OPT_OK;
     }
 
-    // initial step
-    ggml_vec_norm_inv_f32(nx, &step, d);
+    if (opt->just_initialized) {
+        if (pf) {
+            pf[0] = fx;
+        }
+        opt->lbfgs.fx_best = fx;
+
+        // initial step
+        ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+        opt->lbfgs.j                = 0;
+        opt->lbfgs.k                = 1;
+        opt->lbfgs.end              = 0;
+        opt->lbfgs.n_no_improvement = 0;
+        opt->just_initialized       = false;
+    }
+
+    float * fx_best        = &opt->lbfgs.fx_best;
+    float * step           = &opt->lbfgs.step;
+    int * j                = &opt->lbfgs.j;
+    int * k                = &opt->lbfgs.k;
+    int * end              = &opt->lbfgs.end;
+    int * n_no_improvement = &opt->lbfgs.n_no_improvement;
 
-    int j                = 0;
-    int k                = 1;
-    int ls               = 0;
-    int end              = 0;
-    int bound            = 0;
-    int n_no_improvement = 0;
+    int ls     = 0;
+    int bound  = 0;
 
     float ys   = 0.0f;
     float yy   = 0.0f;
     float beta = 0.0f;
 
+    int it = 0;
+
     while (true) {
         // store the current position and gradient vectors
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -15196,32 +18403,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= k) {
-                const float rate = (pf[k%params.past] - fx)/fx;
+            if (params.past <= k[0]) {
+                const float rate = (pf[k[0]%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[k%params.past] = fx;
+            pf[k[0]%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx < fx_best) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx < fx_best[0]) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                n_no_improvement++;
+                n_no_improvement[0]++;
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
             // reached the maximum number of iterations
             return GGML_OPT_DID_NOT_CONVERGE;
         }
@@ -15230,50 +18437,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         //   s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
         //   y_{k+1} = g_{k+1} - g_{k}.
         //
-        ggml_vec_sub_f32(nx, lm[end].s, x, xp);
-        ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+        ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+        ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
 
         // compute scalars ys and yy:
         //     ys = y^t \cdot s    -> 1 / \rho.
         //     yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
-        ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
-        lm[end].ys = ys;
+        lm_ys[end[0]] = ys;
 
         // find new search direction
         //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
 
-        bound = (m <= k) ? m : k;
-        k++;
-        end = (end + 1)%m;
+        bound = (m <= k[0]) ? m : k[0];
+        k[0]++;
+        it++;
+        end[0] = (end[0] + 1)%m;
 
         // initialize search direction with -g
         ggml_vec_neg_f32(nx, d, g);
 
-        j = end;
+        j[0] = end[0];
         for (int i = 0; i < bound; ++i) {
-            j = (j + m - 1) % m;
+            j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
-            lm[j].alpha /= lm[j].ys;
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-            ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+            ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
         }
 
         ggml_vec_scale_f32(nx, d, ys/yy);
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
-            beta /= lm[j].ys;
+            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-            ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
-            j = (j + 1)%m;
+            ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+            j[0] = (j[0] + 1)%m;
         }
 
-        step = 1.0;
+        step[0] = 1.0;
     }
 
     return GGML_OPT_DID_NOT_CONVERGE;
@@ -15298,6 +18506,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
 
                     .adam = {
                         .n_iter = 10000,
+                        .sched  = 1.000f,
+                        .decay  = 0.001f,
                         .alpha  = 0.001f,
                         .beta1  = 0.9f,
                         .beta2  = 0.999f,
@@ -15340,6 +18550,70 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
     return result;
 }
 
+GGML_API void ggml_opt_init(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_opt_params params,
+        int64_t nx) {
+    opt->ctx = ctx;
+    opt->params = params;
+    opt->iter = 0;
+    opt->nx = nx;
+    opt->just_initialized = true;
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                ggml_set_zero(opt->adam.x);
+                ggml_set_zero(opt->adam.g1);
+                ggml_set_zero(opt->adam.g2);
+                ggml_set_zero(opt->adam.m);
+                ggml_set_zero(opt->adam.v);
+                ggml_set_zero(opt->adam.mh);
+                ggml_set_zero(opt->adam.vh);
+                if (opt->adam.pf) {
+                    ggml_set_zero(opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                ggml_set_zero(opt->lbfgs.x);
+                ggml_set_zero(opt->lbfgs.xp);
+                ggml_set_zero(opt->lbfgs.g);
+                ggml_set_zero(opt->lbfgs.gp);
+                ggml_set_zero(opt->lbfgs.d);
+                if (opt->lbfgs.pf) {
+                    ggml_set_zero(opt->lbfgs.pf);
+                }
+                ggml_set_zero(opt->lbfgs.lmal);
+                ggml_set_zero(opt->lbfgs.lmys);
+                ggml_set_zero(opt->lbfgs.lms);
+                ggml_set_zero(opt->lbfgs.lmy);
+            } break;
+    }
+}
+
 enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
@@ -15362,33 +18636,65 @@ enum ggml_opt_result ggml_opt(
 
     enum ggml_opt_result result = GGML_OPT_OK;
 
+    struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+    ggml_opt_init(ctx, opt, params, 0);
+    result = ggml_opt_resume(ctx, opt, f);
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f) {
+
+    // build forward + backward compute graphs
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+    *gf = ggml_build_forward (f);
+    *gb = ggml_build_backward(ctx, gf, true);
+
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+
     // build forward + backward compute graphs
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+    enum ggml_opt_result result = GGML_OPT_OK;
 
-    switch (params.type) {
+    switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
             } break;
     }
 
-    if (params.print_forward_graph) {
-        ggml_graph_print   (&gf);
-        ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
-    }
-
-    if (params.print_backward_graph) {
-        ggml_graph_print   (&gb);
-        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    if (opt->params.print_forward_graph) {
+        ggml_graph_print   (gf);
+        ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
     }
 
-    if (free_ctx) {
-        ggml_free(ctx);
+    if (opt->params.print_backward_graph) {
+        ggml_graph_print   (gb);
+        ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
     }
 
     return result;
@@ -15556,6 +18862,50 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
                 result = ggml_quantize_q8_0(src + start, block, n, n, hist);
             } break;
+#ifdef GGML_USE_K_QUANTS
+        case GGML_TYPE_Q2_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q2_K * block = (block_q2_K*)dst + start / QK_K;
+                result = ggml_quantize_q2_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q3_K * block = (block_q3_K*)dst + start / QK_K;
+                result = ggml_quantize_q3_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q4_K * block = (block_q4_K*)dst + start / QK_K;
+                result = ggml_quantize_q4_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q5_K * block = (block_q5_K*)dst + start / QK_K;
+                result = ggml_quantize_q5_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q6_K * block = (block_q6_K*)dst + start / QK_K;
+                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
+            } break;
+#endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
         default:
             assert(false);
     }
diff --git a/ggml.h b/ggml.h
index 51a616c501bb3eb46bf2b20c727b0c0a0b7a16dc..5ebd9c46c3d5267df4e2a879cfcb633bd23c6085 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_PARAMS        256
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_OPT           4
+#define GGML_MAX_NAME          32
 #define GGML_DEFAULT_N_THREADS 4
 
 #define GGML_ASSERT(x) \
@@ -240,6 +241,13 @@ extern "C" {
         GGML_TYPE_Q5_1 = 7,
         GGML_TYPE_Q8_0 = 8,
         GGML_TYPE_Q8_1 = 9,
+        // k-quantizations
+        GGML_TYPE_Q2_K = 10,
+        GGML_TYPE_Q3_K = 11,
+        GGML_TYPE_Q4_K = 12,
+        GGML_TYPE_Q5_K = 13,
+        GGML_TYPE_Q6_K = 14,
+        GGML_TYPE_Q8_K = 15,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -248,7 +256,8 @@ extern "C" {
 
     enum ggml_backend {
         GGML_BACKEND_CPU = 0,
-        GGML_BACKEND_CUDA = 1,
+        GGML_BACKEND_GPU = 10,
+        GGML_BACKEND_GPU_SPLIT = 20,
     };
 
     // model file types
@@ -262,6 +271,11 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q8_0 = 7,  // except 1d tensors
         GGML_FTYPE_MOSTLY_Q5_0 = 8,  // except 1d tensors
         GGML_FTYPE_MOSTLY_Q5_1 = 9,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
     };
 
     // available tensor operations:
@@ -282,12 +296,14 @@ extern "C" {
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_REPEAT,
+        GGML_OP_REPEAT_BACK,
         GGML_OP_ABS,
         GGML_OP_SGN,
         GGML_OP_NEG,
         GGML_OP_STEP,
         GGML_OP_RELU,
         GGML_OP_GELU,
+        GGML_OP_GELU_QUICK,
         GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
@@ -295,6 +311,7 @@ extern "C" {
         GGML_OP_RMS_NORM_BACK,
 
         GGML_OP_MUL_MAT,
+        GGML_OP_OUT_PROD,
 
         GGML_OP_SCALE,
         GGML_OP_SET,
@@ -310,19 +327,31 @@ extern "C" {
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D_1S,
-        GGML_OP_CONV_1D_2S,
+        GGML_OP_CONV_1D_S1_PH,
+        GGML_OP_CONV_1D_S2_PH,
+        GGML_OP_CONV_2D_SK_P0,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
+        GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_WIN_PART,
+        GGML_OP_WIN_UNPART,
 
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
 
+        GGML_OP_MAP_CUSTOM1,
+        GGML_OP_MAP_CUSTOM2,
+        GGML_OP_MAP_CUSTOM3,
+
+        GGML_OP_CROSS_ENTROPY_LOSS,
+        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+
         GGML_OP_COUNT,
     };
 
@@ -371,11 +400,15 @@ extern "C" {
 
         void * data;
 
-        char name[32];
+        char name[GGML_MAX_NAME];
+
+        void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[16];
+        char padding[4];
     };
 
+    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
@@ -409,6 +442,25 @@ extern "C" {
         bool   no_alloc;   // don't allocate memory for the tensor data
     };
 
+
+    // compute types
+    enum ggml_task_type {
+        GGML_TASK_INIT = 0,
+        GGML_TASK_COMPUTE,
+        GGML_TASK_FINALIZE,
+    };
+
+    struct ggml_compute_params {
+        enum ggml_task_type type;
+
+        // ith = thread index, nth = number of threads
+        int ith, nth;
+
+        // work buffer for all threads
+        size_t wsize;
+        void * wdata;
+    };
+
     // misc
 
     GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
@@ -420,14 +472,17 @@ extern "C" {
     GGML_API void    ggml_print_object (const struct ggml_object * obj);
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
-    GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
-    GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nelements   (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows       (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes      (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);
     GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
     GGML_API float   ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
 
     GGML_API const char * ggml_type_name(enum ggml_type type);
+    GGML_API const char * ggml_op_name  (enum ggml_op   op);
 
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
@@ -436,14 +491,26 @@ extern "C" {
     // TODO: temporary until model loading of ggml examples is refactored
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
+    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);
+
+    // use this to compute the memory overhead of a tensor
+    GGML_API size_t ggml_tensor_overhead(void);
+
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
-    GGML_API void    ggml_free(struct ggml_context * ctx);
+    GGML_API void                  ggml_free(struct ggml_context * ctx);
 
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
-    GGML_API size_t  ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
+
+    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
 
     GGML_API struct ggml_tensor * ggml_new_tensor(
             struct ggml_context * ctx,
@@ -483,6 +550,8 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
     GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
 
+    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
+
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
     GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
@@ -496,8 +565,9 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
-    GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
-    GGML_API void         ggml_set_name(struct ggml_tensor * tensor, const char * name);
+    GGML_API const char *         ggml_get_name(const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
+    GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
 
     //
     // operations on tensors with backpropagation
@@ -522,6 +592,11 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_add1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -545,24 +620,47 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_sub_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_mul(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_mul_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_div(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_div_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_sqr(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sqr_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_sqrt(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sqrt_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_log(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -593,35 +691,76 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_repeat_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_abs_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_sgn(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sgn_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_neg(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_neg_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_step(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_step_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_relu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_relu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // TODO: double-check this computation is correct
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_gelu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_silu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_silu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_silu_back(
@@ -635,10 +774,18 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
@@ -646,14 +793,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    // A: m rows, n columns
-    // B: p rows, n columns (i.e. we transpose it internally)
+    // A: n columns, m rows
+    // B: n columns, p rows  (i.e. we transpose it internally)
     // result is m columns, p rows
     GGML_API struct ggml_tensor * ggml_mul_mat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // A: m columns, n rows,
+    // B: p columns, n rows,
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_out_prod(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     //
     // operations on tensors without backpropagation
     //
@@ -864,6 +1019,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
@@ -909,16 +1075,55 @@ extern "C" {
             float                 min,
             float                 max);
 
-    // padding = 1
+    // TODO: implement general-purpose convolutions
+    // GGML_API struct ggml_tensor * ggml_conv_1d(
+    //        struct ggml_context * ctx,
+    //        struct ggml_tensor  * a,
+    //        struct ggml_tensor  * b,
+    //        int                   s0
+    //        int                   p0,
+    //        int                   d0);
+    //
+    // GGML_API struct ggml_tensor * ggml_conv_2d(
+    //        struct ggml_context * ctx,
+    //        struct ggml_tensor  * a,
+    //        struct ggml_tensor  * b,
+    //        int                   s0,
+    //        int                   s1,
+    //        int                   p0,
+    //        int                   p1,
+    //        int                   d0,
+    //        int                   d1);
+
+    // padding = half
     // TODO: we don't support extra parameters for now
     //       that's why we are hard-coding the stride, padding, and dilation
     //       not great ..
-    GGML_API struct ggml_tensor * ggml_conv_1d_1s(
+    // example:
+    // a:      3   80  768    1
+    // b:   3000   80    1    1
+    // res: 3000  768    1    1
+    // used in whisper
+    GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    GGML_API struct ggml_tensor * ggml_conv_1d_2s(
+    // used in whisper
+    GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
@@ -930,6 +1135,14 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
+    GGML_API struct ggml_tensor * ggml_flash_attn_back(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * q,
+           struct ggml_tensor  * k,
+           struct ggml_tensor  * v,
+           struct ggml_tensor  * d,
+           bool                  masked);
+
     GGML_API struct ggml_tensor * ggml_flash_ff(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -938,21 +1151,106 @@ extern "C" {
             struct ggml_tensor  * c0,
             struct ggml_tensor  * c1);
 
-    // Mapping operations
-    typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
+    // partition into non-overlapping windows with padding if needed
+    // example:
+    // a:   768   64   64    1
+    // w:    14
+    // res: 768   14   14    25
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_part(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w);
+
+    // reverse of ggml_win_part
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_unpart(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w0,
+            int                   h0,
+            int                   w);
+
+    // custom operators
+
+    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
     typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
 
+    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
                    ggml_unary_op_f32_t   fun);
 
+    GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+                   ggml_unary_op_f32_t   fun);
+
     GGML_API struct ggml_tensor * ggml_map_binary_f32(
             struct ggml_context         * ctx,
             struct ggml_tensor          * a,
             struct ggml_tensor          * b,
                    ggml_binary_op_f32_t   fun);
 
+    GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+                   ggml_binary_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun);
+
+    // loss function
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b);
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+            struct ggml_tensor          * c);
+
     //
     // automatic differentiation
     //
@@ -969,6 +1267,11 @@ extern "C" {
     GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
+    GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
+
+    GGML_API void               ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+    GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+
     // print info and performance information for the graph
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
@@ -1042,6 +1345,8 @@ extern "C" {
         struct {
             int n_iter;
 
+            float sched; // schedule multiplier (fixed, decay or warmup)
+            float decay; // weight decay for AdamW, use 0.0f to disable
             float alpha; // learning rate
             float beta1;
             float beta2;
@@ -1066,6 +1371,49 @@ extern "C" {
         } lbfgs;
     };
 
+    struct ggml_opt_context {
+        struct ggml_context * ctx;
+        struct ggml_opt_params params;
+
+        int iter;
+        int64_t nx; // number of parameter elements
+
+        bool just_initialized;
+
+        struct {
+            struct ggml_tensor * x;  // view of the parameters
+            struct ggml_tensor * g1; // gradient
+            struct ggml_tensor * g2; // gradient squared
+            struct ggml_tensor * m;  // first moment
+            struct ggml_tensor * v;  // second moment
+            struct ggml_tensor * mh; // first moment hat
+            struct ggml_tensor * vh; // second moment hat
+            struct ggml_tensor * pf; // past function values
+            float fx_best;
+            float fx_prev;
+            int n_no_improvement;
+        } adam;
+
+        struct {
+            struct ggml_tensor * x;    // current parameters
+            struct ggml_tensor * xp;   // previous parameters
+            struct ggml_tensor * g;    // current gradient
+            struct ggml_tensor * gp;   // previous gradient
+            struct ggml_tensor * d;    // search direction
+            struct ggml_tensor * pf;   // past function values
+            struct ggml_tensor * lmal; // the L-BFGS memory alpha
+            struct ggml_tensor * lmys; // the L-BFGS memory ys
+            struct ggml_tensor * lms;  // the L-BFGS memory s
+            struct ggml_tensor * lmy;  // the L-BFGS memory y
+            float fx_best;
+            float step;
+            int j;
+            int k;
+            int end;
+            int n_no_improvement;
+        } lbfgs;
+    };
+
     GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
 
     // optimize the function defined by the tensor f
@@ -1074,6 +1422,27 @@ extern "C" {
             struct ggml_opt_params params,
             struct ggml_tensor * f);
 
+    // initialize optimizer context
+    GGML_API void ggml_opt_init(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_opt_params params,
+            int64_t nx);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume_g(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f,
+            struct ggml_cgraph * gf,
+            struct ggml_cgraph * gb);
+
     //
     // quantization
     //
index 0cdd4a1d49fe0e9cc0cdd504bd7a739602ea648b..65c57f7d0d5d5ead4688c9f91afa94cfb2e29f7d 100644 (file)
 #include <regex>
 #include <random>
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 #if defined(GGML_BIG_ENDIAN)
 #include <bit>
 
@@ -1468,7 +1472,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 1);
 
-            cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
+            cur = ggml_conv_1d_s1_ph(ctx0, model.e_conv_1_w, mel);
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0,
                         model.e_conv_1_b,
@@ -1479,7 +1483,7 @@ static bool whisper_encode_internal(
 
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
+            cur = ggml_conv_1d_s2_ph(ctx0, model.e_conv_2_w, cur);
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0,
                         model.e_conv_2_b,