]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tts : add OuteTTS support (#10784)
authorGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 17:27:21 +0000 (19:27 +0200)
committerGitHub <redacted>
Wed, 18 Dec 2024 17:27:21 +0000 (19:27 +0200)
* server : add "tokens" output

ggml-ci

* server : output embeddings for all tokens when pooling = none

ggml-ci

* server : be explicit about the pooling type in the tests

ggml-ci

* server : do not normalize embeddings when there is no pooling

ggml-ci

* llama : add OuteTTS support (wip)

* wip

* extract features

* first conv

* group norm

* resnet conv

* resnet

* attn

* pos net

* layer norm

* convnext

* head

* hann window

* fix n_embd + remove llama.cpp hacks

* compute hann window

* fft

* spectrum processing

* clean-up

* tts : receive input text and generate codes

* clip : fix new conv name

* tts : minor fix

* tts : add header + minor fixes

ggml-ci

* tts : add matchematical constant

ggml-ci

* tts : fix sampling + cut initial noise

* tts : fixes

* tts : update default samplers

ggml-ci

* tts : text pre-processing

* tts : outetts-voc -> wavtokenizer-dec

* tts : remove hardcoded constants

ggml-ci

* tts : fix tensor shapes

* llama : refactor wavtokenizer tensors

ggml-ci

* cont

ggml-ci

* cont [no ci]

* llama : update WavTokenizer to non-causal attn

* llama : handle no-vocab detokenization

* tts : add Python example for OuteTTS (wip)

* tts : extend python example to generate spectrogram

ggml-ci

* server : fix rebase artifacts

* tts : enable "return_tokens" in Python example

ggml-ci

* tts : minor fixes

* common : support HF download for vocoder

19 files changed:
common/arg.cpp
common/common.cpp
common/common.h
convert_hf_to_gguf.py
examples/CMakeLists.txt
examples/llava/clip.cpp
examples/tts/CMakeLists.txt [new file with mode: 0644]
examples/tts/convert_pt_to_hf.py [new file with mode: 0644]
examples/tts/tts-outetts.py [new file with mode: 0644]
examples/tts/tts.cpp [new file with mode: 0644]
ggml/include/ggml.h
ggml/src/ggml.c
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
gguf-py/tests/test_quants.py
include/llama.h
src/llama-vocab.cpp
src/llama.cpp

index 3d55289c331922d39560c20f449b537d31b1198b..e5ddd8318f787cf4fd8067ab813337f6e9e7cd10 100644 (file)
@@ -119,29 +119,33 @@ std::string common_arg::to_string() {
 // utils
 //
 
-static void common_params_handle_model_default(common_params & params) {
-    if (!params.hf_repo.empty()) {
+static void common_params_handle_model_default(
+        std::string & model,
+        std::string & model_url,
+        std::string & hf_repo,
+        std::string & hf_file) {
+    if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
+        if (hf_file.empty()) {
+            if (model.empty()) {
                 throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
             }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+            hf_file = model;
+        } else if (model.empty()) {
             // this is to avoid different repo having same file name, or same file name in different subdirs
-            std::string filename = params.hf_repo + "_" + params.hf_file;
+            std::string filename = hf_repo + "_" + hf_file;
             // to make sure we don't have any slashes in the filename
             string_replace_all(filename, "/", "_");
-            params.model = fs_get_cache_file(filename);
+            model = fs_get_cache_file(filename);
         }
-    } else if (!params.model_url.empty()) {
-        if (params.model.empty()) {
-            auto f = string_split<std::string>(params.model_url, '#').front();
+    } else if (!model_url.empty()) {
+        if (model.empty()) {
+            auto f = string_split<std::string>(model_url, '#').front();
             f = string_split<std::string>(f, '?').front();
-            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
+            model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
-    } else if (params.model.empty()) {
-        params.model = DEFAULT_MODEL_PATH;
+    } else if (model.empty()) {
+        model = DEFAULT_MODEL_PATH;
     }
 }
 
@@ -276,7 +280,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }
 
-    common_params_handle_model_default(params);
+    // TODO: refactor model params in a common struct
+    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file);
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -842,7 +848,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_sparam());
     add_opt(common_arg(
-        {"--sampling-seq"}, "SEQUENCE",
+        {"--sampling-seq", "--sampler-seq"}, "SEQUENCE",
         string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
         [](common_params & params, const std::string & value) {
             params.sampling.samplers = common_sampler_types_from_chars(value);
@@ -1581,6 +1587,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.hf_file = value;
         }
     ).set_env("LLAMA_ARG_HF_FILE"));
+    add_opt(common_arg(
+        {"-hfrv", "--hf-repo-v"}, "REPO",
+        "Hugging Face model repository for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HF_REPO_V"));
+    add_opt(common_arg(
+        {"-hffv", "--hf-file-v"}, "FILE",
+        "Hugging Face model file for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_file = value;
+        }
+    ).set_env("LLAMA_ARG_HF_FILE_V"));
     add_opt(common_arg(
         {"-hft", "--hf-token"}, "TOKEN",
         "Hugging Face access token (default: value from HF_TOKEN environment variable)",
@@ -2178,5 +2198,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
 
+    add_opt(common_arg(
+        {"-mv", "--model-vocoder"}, "FNAME",
+        "vocoder model for audio generation (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.model = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+
     return ctx_arg;
 }
index 05d3ba766e38bf2490f00ad65cce157af25265d1..20be9291161ca4de3c0fa15e9f8a236d7c2bb99b 100644 (file)
@@ -1095,7 +1095,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 #define CURL_MAX_RETRY 3
 #define CURL_RETRY_DELAY_SECONDS 2
 
-static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
+static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
     int remaining_attempts = max_attempts;
 
     while (remaining_attempts > 0) {
@@ -1119,7 +1119,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
 }
 
 static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
     if (!curl) {
@@ -1192,11 +1191,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
         std::string etag;
         std::string last_modified;
     };
+
     common_load_model_from_url_headers headers;
+
     {
         typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
         auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
+            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
 
             static std::regex header_regex("([^:]+): (.*)\r\n");
             static std::regex etag_regex("ETag", std::regex_constants::icase);
index ec0e49f6f1806e3a6da5008a689a850791e73136..1d2bd932c211dc6132198619bf1b075c6262cdfb 100644 (file)
@@ -80,6 +80,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LLAVA,
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
+    LLAMA_EXAMPLE_TTS,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -159,6 +160,7 @@ struct common_params_sampling {
 
 struct common_params_speculative {
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
     int32_t n_ctx        =     0; // draft context size
     int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
@@ -172,6 +174,14 @@ struct common_params_speculative {
     std::string model = ""; // draft model for speculative decoding                          // NOLINT
 };
 
+struct common_params_vocoder {
+    std::string hf_repo = ""; // HF repo                                                     // NOLINT
+    std::string hf_file = ""; // HF file                                                     // NOLINT
+
+    std::string model     = ""; // model path                                                // NOLINT
+    std::string model_url = ""; // model url to download                                     // NOLINT
+};
+
 struct common_params {
     int32_t n_predict             =    -1; // new tokens to predict
     int32_t n_ctx                 =  4096; // context size
@@ -214,8 +224,9 @@ struct common_params {
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    struct common_params_sampling sampling;
+    struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
+    struct common_params_vocoder     vocoder;
 
     std::string model                = ""; // model path                                                    // NOLINT
     std::string model_alias          = ""; // model alias                                                   // NOLINT
index 9dc1673bc2c06cddaf1a24b38a14985f8162c23d..4a0b00f69c6994e5b6b81f1d420dccad77e3d409 100755 (executable)
@@ -221,17 +221,17 @@ class Model:
             self.gguf_writer.add_context_length(n_ctx)
             logger.info(f"gguf: context length = {n_ctx}")
 
-        n_embd = self.find_hparam(["hidden_size", "n_embd"])
-        self.gguf_writer.add_embedding_length(n_embd)
-        logger.info(f"gguf: embedding length = {n_embd}")
+        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
+            self.gguf_writer.add_embedding_length(n_embd)
+            logger.info(f"gguf: embedding length = {n_embd}")
 
         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
             logger.info(f"gguf: feed forward length = {n_ff}")
 
-        n_head = self.find_hparam(["num_attention_heads", "n_head"])
-        self.gguf_writer.add_head_count(n_head)
-        logger.info(f"gguf: head count = {n_head}")
+        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
+            self.gguf_writer.add_head_count(n_head)
+            logger.info(f"gguf: head count = {n_head}")
 
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
@@ -296,7 +296,9 @@ class Model:
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                # TODO: why do we squeeze here?
+                # data = data_torch.squeeze().numpy()
+                data = data_torch.numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
@@ -324,6 +326,8 @@ class Model:
                             gguf.MODEL_TENSOR.TIME_MIX_W2,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+                            gguf.MODEL_TENSOR.POSNET_NORM1,
+                            gguf.MODEL_TENSOR.POSNET_NORM2,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -689,6 +693,9 @@ class Model:
         return res
         # Marker: End get_vocab_base_pre
 
+    def _set_vocab_none(self) -> None:
+        self.gguf_writer.add_tokenizer_model("none")
+
     def _set_vocab_gpt2(self) -> None:
         tokens, toktypes, tokpre = self.get_vocab_base()
         self.gguf_writer.add_tokenizer_model("gpt2")
@@ -2027,6 +2034,44 @@ class Qwen2VLModel(Model):
             yield name, data
 
 
+@Model.register("WavTokenizerDec")
+class WavTokenizerDecModel(Model):
+    model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if \
+                name.endswith("codebook.cluster_size") or \
+                name.endswith("codebook.embed_avg") or \
+                name.endswith("codebook.inited"):
+            logger.debug(f"Skipping {name!r}")
+            return []
+
+        logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_vocab(self):
+        self._set_vocab_none()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vocab_size         (self.hparams["vocab_size"])
+        self.gguf_writer.add_features_length    (self.hparams["n_embd_features"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"])
+        self.gguf_writer.add_group_norm_eps     (self.hparams["group_norm_epsilon"])
+        self.gguf_writer.add_group_norm_groups  (self.hparams["group_norm_groups"])
+
+        self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"])
+        self.gguf_writer.add_posnet_block_count     (self.hparams["posnet"]["n_layer"])
+
+        self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"])
+        self.gguf_writer.add_convnext_block_count     (self.hparams["convnext"]["n_layer"])
+
+        self.gguf_writer.add_causal_attention(False)
+
+
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):
     model_arch = gguf.MODEL_ARCH.QWEN2MOE
index 21b31392e81d087610256c39d215b806b7b17825..66cfab2c3b7962b01ac1cff5718b4e3eeb8b7ce4 100644 (file)
@@ -51,6 +51,7 @@ else()
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
+    add_subdirectory(tts)
     add_subdirectory(gen-docs)
     if (NOT GGML_BACKEND_DL)
         # these examples use the backends directly and cannot be built with dynamic loading
index ba28c07c6aeeca809aaa2bf18c3cbe9da26bc772..463b7c865b90c8bbc808e1b599264fef9400888d 100644 (file)
@@ -896,7 +896,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
                 mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
                 // stride = 1, padding = 1, bias is nullptr
-                block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
 
                 // layer norm
                 // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
@@ -944,7 +944,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             // block_2
             {
                 // stride = 2
-                block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
 
                 // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                 // layer norm
@@ -1005,7 +1005,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             // mlp_2 ne [24, 24, 2048, 1]
             mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
             // weight ne = [3, 3, 2048, 1]
-            struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+            struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
             peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
             peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
             mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
new file mode 100644 (file)
index 0000000..c72bd81
--- /dev/null
@@ -0,0 +1,5 @@
+set(TARGET llama-tts)
+add_executable(${TARGET} tts.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/tts/convert_pt_to_hf.py b/examples/tts/convert_pt_to_hf.py
new file mode 100644 (file)
index 0000000..8909a65
--- /dev/null
@@ -0,0 +1,180 @@
+# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token to HF format
+# the goal is to be able to reuse the convert_hf_to_gguf.py after that to create a GGUF file with the WavTokenizer decoder
+#
+# TODO: this script is LLM-generated and probably very inefficient and should be rewritten
+
+import torch
+import json
+import os
+import sys
+import re
+
+from safetensors.torch import save_file
+
+# default
+model_path = './model.pt';
+
+# read from CLI
+if len(sys.argv) > 1:
+    model_path = sys.argv[1]
+
+# get the directory of the input model
+path_dst = os.path.dirname(model_path)
+
+print(f"Loading model from {model_path}")
+
+model = torch.load(model_path, map_location='cpu')
+
+#print(model)
+
+# print all keys
+for key in model.keys():
+    print(key)
+    if key == 'hyper_parameters':
+        #print(model[key])
+        # dump as json pretty
+        print(json.dumps(model[key], indent=4))
+    #if key != 'state_dict' and key != 'optimizer_states':
+    #    print(model[key])
+
+# Check if the loaded model is a state_dict or a model instance
+if isinstance(model, torch.nn.Module):
+    state_dict = model.state_dict()
+else:
+    state_dict = model
+
+# Print the structure of the state_dict to understand its format
+print("State dictionary keys:")
+for key in state_dict.keys():
+    print(key)
+
+# Ensure the state_dict is flat and contains only torch.Tensor objects
+def flatten_state_dict(state_dict, parent_key='', sep='.'):
+    items = []
+    items_new = []
+
+    for k, v in state_dict.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, torch.Tensor):
+            items.append((new_key, v))
+        elif isinstance(v, dict):
+            items.extend(flatten_state_dict(v, new_key, sep=sep).items())
+            return dict(items)
+
+    size_total_mb = 0
+
+    for key, value in list(items):
+        # keep only what we need for inference
+        if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
+           not key.startswith('state_dict.backbone.') and \
+           not key.startswith('state_dict.head.out'):
+               print('Skipping key: ', key)
+               continue
+
+        new_key = key
+
+        new_key = new_key.replace('state_dict.', '')
+        new_key = new_key.replace('pos_net', 'posnet')
+
+        # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
+        if new_key.startswith("backbone.posnet."):
+            match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
+            if match:
+               new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"
+
+        # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
+        if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
+            new_key = "backbone.embedding.weight"
+
+        # these are the only rows used
+        # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
+        if new_key.endswith("norm.scale.weight"):
+            new_key = new_key.replace("norm.scale.weight", "norm.weight")
+            value = value[0]
+
+        if new_key.endswith("norm.shift.weight"):
+            new_key = new_key.replace("norm.shift.weight", "norm.bias")
+            value = value[0]
+
+        if new_key.endswith("gamma"):
+            new_key = new_key.replace("gamma", "gamma.weight")
+
+        # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
+        if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
+            value = value.unsqueeze(1)
+
+        if new_key.endswith("dwconv.bias"):
+            value = value.unsqueeze(1)
+
+        size_mb = value.element_size() * value.nelement() / (1024 * 1024)
+        print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
+
+        size_total_mb += size_mb
+
+        #print(key, '->', new_key, ': ', value)
+        #print(key, '->', new_key)
+
+        items_new.append((new_key, value))
+
+    print(f"Total size: {size_total_mb:8.2f} MB")
+
+    return dict(items_new)
+
+flattened_state_dict = flatten_state_dict(state_dict)
+
+
+# Convert the model to the safetensors format
+output_path = path_dst + '/model.safetensors'
+save_file(flattened_state_dict, output_path)
+
+print(f"Model has been successfully converted and saved to {output_path}")
+
+# Calculate the total size of the .safetensors file
+total_size = os.path.getsize(output_path)
+
+# Create the weight map
+weight_map = {
+    "model.safetensors": ["*"]  # Assuming all weights are in one file
+}
+
+# Create metadata for the index.json file
+metadata = {
+    "total_size": total_size,
+    "weight_map": weight_map
+}
+
+# Save the metadata to index.json
+index_path = path_dst + '/index.json'
+with open(index_path, 'w') as f:
+    json.dump(metadata, f, indent=4)
+
+print(f"Metadata has been saved to {index_path}")
+
+config = {
+    "architectures": [
+        "WavTokenizerDec"
+    ],
+    "hidden_size": 1282,
+    "n_embd_features": 512,
+    "n_ff": 2304,
+    "vocab_size": 4096,
+    "n_head": 1,
+    "layer_norm_epsilon": 1e-6,
+    "group_norm_epsilon": 1e-6,
+    "group_norm_groups": 32,
+    "max_position_embeddings": 8192, # ?
+    "n_layer": 12,
+    "posnet": {
+        "n_embd": 768,
+        "n_layer": 6
+    },
+    "convnext": {
+        "n_embd": 768,
+        "n_layer": 12
+    },
+}
+
+with open(path_dst + '/config.json', 'w') as f:
+    json.dump(config, f, indent=4)
+
+print(f"Config has been saved to {path_dst + 'config.json'}")
diff --git a/examples/tts/tts-outetts.py b/examples/tts/tts-outetts.py
new file mode 100644 (file)
index 0000000..0f81192
--- /dev/null
@@ -0,0 +1,175 @@
+import sys
+#import json
+#import struct
+import requests
+import re
+
+def process_text(text: str):
+    text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
+    text = re.sub(r'[-_/,\.\\]', ' ', text)
+    text = re.sub(r'[^a-z\s]', '', text)
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text.split()
+
+# usage:
+# python tts-outetts.py http://server-llm:port http://server-dec:port "text"
+
+if len(sys.argv) <= 3:
+    print("usage: python tts-outetts.py http://server-llm:port http://server-dec:port \"text\"")
+    exit(1)
+
+host_llm = sys.argv[1]
+host_dec = sys.argv[2]
+text = sys.argv[3]
+
+prefix = """<|im_start|>
+<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>"""
+
+words = process_text(text)
+words = "<|text_sep|>".join([i.strip() for i in words])
+words += "<|text_end|>\n"
+
+# voice data
+# TODO: load from json
+#suffix = """<|audio_start|>
+#the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
+#overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
+#package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
+#from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|>
+#just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|>
+#two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|>
+#people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|>
+#is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|>
+#pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|>
+#remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|>
+#sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|>
+#i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|>
+#have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|>
+#some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|>
+#critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|>
+#about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|>
+#some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|>
+#of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|>
+#the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|>
+#gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|>
+#aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|>
+#but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|>
+#its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|>
+#still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|>
+#really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|>
+#enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|>
+#and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|>
+#it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|>
+#looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
+#lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>"""
+
+# TODO: tokenization is slow for some reason - here is pre-tokenized input
+suffix = [ 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, 152460, 153375, 151670, 198, 74455,
+          155808, 151669, 151799, 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, 151970, 153413,
+          152419, 153334, 153289, 153374, 153199, 152040, 153260, 152721, 152680, 153297, 152419, 153248, 152400,
+          152691, 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, 152256, 152991, 152299, 152688, 153163,
+          153016, 152789, 153198, 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, 152461, 153321,
+          153309, 151750, 152137, 153340, 152573, 152267, 153347, 151789, 152681, 153339, 151992, 152512, 151751,
+          152179, 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, 152311, 151670, 198, 1499, 155791,
+          151669, 152276, 152454, 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, 153043, 152325,
+          153267, 152622, 151670, 198, 4250, 155797, 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
+          152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, 152112, 153204, 151722, 152542, 151670, 198,
+          19789, 155796, 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, 152191, 151734, 152312, 152810,
+          152237, 153224, 153169, 153224, 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, 152265, 151946,
+          151808, 152412, 152363, 152305, 153156, 152733, 152810, 153157, 152016, 152100, 152069, 153234, 152317,
+          152589, 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, 153376, 152272, 152433, 152325,
+          151941, 151670, 198, 285, 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, 152474, 152680,
+          152157, 153255, 152324, 151682, 151670, 198, 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
+          152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, 153070, 151883, 152890, 152489, 153144,
+          153375, 152358, 151685, 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, 151902, 152720,
+          153377, 152027, 152378, 152821, 153207, 153459, 153028, 153068, 152507, 153255, 152158, 152921, 151958,
+          152609, 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, 152606, 152162, 152186, 153071,
+          152244, 153118, 153375, 153018, 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, 153380,
+          153502, 152702, 152115, 153181, 152735, 153277, 153457, 152393, 153112, 152595, 151670, 198, 19098, 155808,
+          151669, 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, 153163, 152922, 153402, 152034,
+          152591, 153438, 152215, 151673, 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, 152718,
+          152862, 153347, 151670, 198, 72, 155780, 151669, 151795, 152111, 152746, 152377, 153471, 152309, 151670, 198,
+          19016, 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, 152939, 152536, 152091, 151815, 152733,
+          151672, 151670, 198, 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, 153504, 152589, 153333,
+          151839, 151941, 153038, 153180, 151670, 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, 152801,
+          152985, 153400, 152393, 152818, 152765, 152249, 152600, 151699, 152302, 152752, 153018, 153009, 151992,
+          153054, 152847, 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, 152048, 152757, 152428,
+          153195, 151906, 153006, 153178, 153250, 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
+          152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, 153321, 152217, 153039, 152935, 153400, 152122,
+          152531, 153106, 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, 152901, 152885, 152594,
+          153446, 153080, 151670, 198, 14689, 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, 151673,
+          151690, 151698, 152714, 152846, 152981, 153171, 153384, 153364, 153188, 153246, 151670, 198, 1055, 155779,
+          151669, 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, 155780, 151669, 153483, 153240, 152241,
+          152558, 152697, 153046, 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, 153034, 153434,
+          153372, 153347, 151887, 152453, 152758, 152133, 152510, 152694, 152431, 152321, 153088, 152676, 152223,
+          152581, 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, 152903, 152859, 152989, 151748,
+          152669, 152661, 152650, 152409, 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, 152988,
+          152894, 151819, 152391, 153019, 152058, 153062, 153230, 151826, 152112, 152306, 152264, 152769, 153390,
+          152384, 152435, 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, 151919, 151893, 152558,
+          152817, 152946, 152956, 152129, 152715, 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
+          151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, 152679, 152533, 152382, 152374, 152611, 153341,
+          153163, 152285, 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, 151669, 151764, 152360, 153295,
+          152634, 153342, 152199, 152271, 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, 152016, 152385,
+          152629, 152495, 151826, 153321, 152958, 152180, 151886, 153432, 152922, 152128, 153024, 153040, 152593,
+          152287, 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, 152680, 153331, 151699, 152316, 152938,
+          152289, 152433, 153384, 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, 152489, 151941,
+          152049, 152034, 153053, 152179, 153160, 151676, 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
+          152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, 153135, 152291, 153235, 152143, 152583,
+          152402, 153483, 152678, 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, 152548, 153442,
+          152109, 152659, 153325, 152781, 152570, 152957, 151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
+          151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, 151792, 153409, 153327, 152990, 151670, 198,
+          275, 155781, 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, 151670, 198, 94273, 155799,
+          151669, 152953, 152938, 153427, 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, 152257,
+          152987, 152777, 153448, 152408, 151696, 152408, 152326, 152699, 151670, 198, 385, 16239, 155828, 151669,
+          152306, 152268, 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, 152918, 152923, 152467,
+          152331, 153053, 153330, 151889, 153444, 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
+          152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, 152109, 152255, 151739, 152267, 152759,
+          153318, 153165, 153349, 151670, ]
+
+response = requests.post(
+    host_llm + "/completion",
+    json={
+        "prompt": [prefix + words, *suffix],
+        "n_predict": 1024,
+        "cache_prompt": True,
+        "return_tokens": True,
+        "samplers": ["top_k"],
+        "top_k": 16,
+        "seed": 1003,
+    }
+)
+
+response_json = response.json()
+
+#print(json.dumps(response_json, indent=4))
+#print(json.dumps(response_json["prompt"], indent=4).replace("\\n", "\n"))
+#print(json.dumps(response_json["timings"], indent=4))
+#print(json.dumps(response_json["tokens"], indent=4))
+
+codes = response_json["tokens"]
+
+codes = [t - 151672 for t in codes if t >= 151672 and t <= 155772]
+
+response = requests.post(
+    host_dec + "/embeddings",
+    json={
+        "input": [*codes],
+    }
+)
+
+response_json = response.json()
+
+#print(json.dumps(response_json, indent=4))
+
+# spectrogram
+embd = response_json[0]["embedding"]
+
+n_codes = len(embd)
+n_embd = len(embd[0])
+
+print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))
+
+# post-process the spectrogram to convert to audio
+# TODO: see the tts.cpp:embd_to_audio() and implement it in Python
+print('converting to audio ...')
+print('TODO: see the tts.cpp:embd_to_audio() and implement it in Python')
diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
new file mode 100644 (file)
index 0000000..7f36b80
--- /dev/null
@@ -0,0 +1,932 @@
+#include "arg.h"
+#include "common.h"
+#include "sampling.h"
+#include "log.h"
+#include "llama.h"
+
+#define _USE_MATH_DEFINES // For M_PI on MSVC
+
+#include <algorithm>
+#include <cmath>
+#include <cstdio>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <string>
+#include <thread>
+#include <vector>
+
+//
+// Terminal utils
+//
+
+#define SQR(X)    ((X) * (X))
+#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40
+
+/**
+ * Quantizes 24-bit RGB to xterm256 code range [16,256).
+ */
+static int rgb2xterm256(int r, int g, int b) {
+    unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377};
+    int av, ir, ig, ib, il, qr, qg, qb, ql;
+    av = r * .299 + g * .587 + b * .114 + .5;
+    ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8;
+    qr = cube[(ir = UNCUBE(r))];
+    qg = cube[(ig = UNCUBE(g))];
+    qb = cube[(ib = UNCUBE(b))];
+    if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=
+        SQR(ql - r) + SQR(ql - g) + SQR(ql - b))
+        return ir * 36 + ig * 6 + ib + 020;
+    return il + 0350;
+}
+
+static std::string set_xterm256_foreground(int r, int g, int b) {
+    int x = rgb2xterm256(r, g, b);
+    std::ostringstream oss;
+    oss << "\033[38;5;" << x << "m";
+    return oss.str();
+}
+
+const std::vector<std::string> k_colors = {
+    set_xterm256_foreground(220,   5,  12),
+    set_xterm256_foreground(232,  96,  28),
+    set_xterm256_foreground(241, 147,  45),
+    set_xterm256_foreground(246, 193,  65),
+    set_xterm256_foreground(247, 240,  86),
+    set_xterm256_foreground(144, 201, 135),
+    set_xterm256_foreground( 78, 178, 101),
+};
+
+static void print_usage(int, char ** argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n    %s -m model.gguf -p \"Hello!\"\n", argv[0]);
+    LOG("\n");
+}
+
+struct wav_header {
+    char riff[4] = {'R', 'I', 'F', 'F'};
+    uint32_t chunk_size;
+    char wave[4] = {'W', 'A', 'V', 'E'};
+    char fmt[4] = {'f', 'm', 't', ' '};
+    uint32_t fmt_chunk_size = 16;
+    uint16_t audio_format = 1; // PCM
+    uint16_t num_channels = 1; // Mono
+    uint32_t sample_rate;
+    uint32_t byte_rate;
+    uint16_t block_align;
+    uint16_t bits_per_sample = 16;
+    char data[4] = {'d', 'a', 't', 'a'};
+    uint32_t data_size;
+};
+
+static void save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
+    std::ofstream file(fname, std::ios::binary);
+    if (!file) {
+        LOG_ERR("%s: Failed to open file '%s' for writing", __func__, fname.c_str());
+        return;
+    }
+
+    wav_header header;
+    header.sample_rate = sample_rate;
+    header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
+    header.block_align = header.num_channels * (header.bits_per_sample / 8);
+    header.data_size = data.size() * (header.bits_per_sample / 8);
+    header.chunk_size = 36 + header.data_size;
+
+    file.write(reinterpret_cast<const char*>(&header), sizeof(header));
+
+    for (const auto & sample : data) {
+        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
+        file.write(reinterpret_cast<const char*>(&pcm_sample), sizeof(pcm_sample));
+    }
+
+    file.close();
+}
+
+static void fill_hann_window(int length, bool periodic, float * output) {
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+    }
+}
+
+// very poor-man fft
+static void twiddle(float * real, float * imag, int k, int N) {
+    float angle = 2 * M_PI * k / N;
+    *real = cos(angle);
+    *imag = sin(angle);
+}
+
+static void irfft(int n, const float * inp_cplx, float * out_real) {
+    int N = n / 2 + 1;
+
+    std::vector<float> real_input(N);
+    std::vector<float> imag_input(N);
+    for (int i = 0; i < N; ++i) {
+        real_input[i] = inp_cplx[2 * i];
+        imag_input[i] = inp_cplx[2 * i + 1];
+    }
+
+    std::vector<float> real_output(n);
+    std::vector<float> imag_output(n);
+
+    for (int k = 0; k < n; ++k) {
+        real_output[k] = 0.0f;
+        imag_output[k] = 0.0f;
+        for (int m = 0; m < N; ++m) {
+            float twiddle_real;
+            float twiddle_imag;
+
+            twiddle(&twiddle_real, &twiddle_imag, k * m, n);
+
+            real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
+            imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
+        }
+    }
+
+    for (int i = 0; i < n; ++i) {
+        out_real[i] = real_output[i] / N;
+    }
+}
+
+//
+//  y = torch.nn.functional.fold(
+//       data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+//  )[:, 0, 0, pad:-pad]
+//
+// data.shape =  torch.Size([1, 1280, 261])
+// output_size =  84480
+// win_length =  1280
+// hop_length =  320
+// pad =  480
+//
+static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
+    int64_t output_height = n_out;
+    int64_t kernel_w = n_win;
+    int64_t stride_w = n_hop;
+    int64_t width    = n_out;
+
+    output.resize(width, 0.0f);
+
+    int64_t col_idx = 0;
+    for (int64_t w_col = 0; w_col < width; ++w_col) {
+        int64_t start = w_col * stride_w - n_pad;
+        int64_t end   = start + kernel_w;
+
+        for (int64_t w_im = start; w_im < end; ++w_im) {
+            if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
+                output[w_im] += data[col_idx];
+            }
+            col_idx++;
+        }
+    }
+
+    output.resize(n_out - 2 * n_pad);
+}
+
+// TODO: not optimized at all
+static std::vector<float> embd_to_audio(
+        const float * embd,
+        const int n_codes,
+        const int n_embd,
+        const int n_thread) {
+    const int n_fft = 1280;
+    const int n_hop = 320;
+    const int n_win = 1280;
+    const int n_pad = (n_win - n_hop)/2;
+    const int n_out = (n_codes - 1)*n_hop + n_win;
+
+    std::vector<float> hann(n_fft);
+
+    fill_hann_window(hann.size(), true, hann.data());
+
+    int n_spec = n_embd*n_codes;
+
+    std::vector<float> E (n_spec);
+    std::vector<float> S (n_spec);
+    std::vector<float> ST(n_spec);
+
+    for (int l = 0; l < n_codes; ++l) {
+        for (int k = 0; k < n_embd; ++k) {
+            E[k*n_codes + l] = embd[l*n_embd + k];
+        }
+    }
+
+    for (int k = 0; k < n_embd/2; ++k) {
+        for (int l = 0; l < n_codes; ++l) {
+            float mag = E[(k           )*n_codes + l];
+            float phi = E[(k + n_embd/2)*n_codes + l];
+
+            mag = exp(mag);
+
+            if (mag > 1e2) {
+                mag = 1e2;
+            }
+            S[2*(k*n_codes + l) + 0] = mag*cosf(phi);
+            S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
+        }
+    }
+
+    for (int l = 0; l < n_codes; ++l) {
+        for (int k = 0; k < n_embd/2; ++k) {
+            ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
+            ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
+        }
+    }
+
+    std::vector<float> res  (n_codes*n_fft);
+    std::vector<float> hann2(n_codes*n_fft);
+
+    std::vector<std::thread> workers(n_thread);
+    for (int i = 0; i < n_thread; ++i) {
+        workers[i] = std::thread([&, i]() {
+            for (int l = i; l < n_codes; l += n_thread) {
+                irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
+                for (int j = 0; j < n_fft; ++j) {
+                    res  [l*n_fft + j] *= hann[j];
+                    hann2[l*n_fft + j]  = hann[j] * hann[j];
+                }
+            }
+        });
+    }
+    for (int i = 0; i < n_thread; ++i) {
+        workers[i].join();
+    }
+
+    std::vector<float> audio;
+    std::vector<float> env;
+
+    fold(res,   n_out, n_win, n_hop, n_pad, audio);
+    fold(hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once
+
+    for (size_t i = 0; i < audio.size(); ++i) {
+        audio[i] /= env[i];
+    }
+
+    return audio;
+}
+
+static const std::map<int, std::string> ones = {
+    {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
+    {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
+    {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
+    {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
+};
+
+static const std::map<int, std::string> tens = {
+    {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
+    {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
+};
+
+// Convert a number less than 1000 to words
+static std::string convert_less_than_thousand(int num) {
+    std::string result;
+
+    if (num >= 100) {
+        result += ones.at(num / 100) + " hundred ";
+        num %= 100;
+    }
+
+    if (num >= 20) {
+        result += tens.at(num / 10);
+        if (num % 10 > 0) {
+            result += "-" + ones.at(num % 10);
+        }
+    } else if (num > 0) {
+        result += ones.at(num);
+    }
+
+    return result;
+}
+
+static std::string number_to_words(const std::string & number_str) {
+    try {
+        size_t decimal_pos = number_str.find('.');
+        std::string integer_part = number_str.substr(0, decimal_pos);
+
+        int int_number = std::stoi(integer_part);
+        std::string result;
+
+        if (int_number == 0) {
+            result = "zero";
+        } else {
+            if (int_number >= 1000000000) {
+                int billions = int_number / 1000000000;
+                result += convert_less_than_thousand(billions) + " billion ";
+                int_number %= 1000000000;
+            }
+
+            if (int_number >= 1000000) {
+                int millions = int_number / 1000000;
+                result += convert_less_than_thousand(millions) + " million ";
+                int_number %= 1000000;
+            }
+
+            if (int_number >= 1000) {
+                int thousands = int_number / 1000;
+                result += convert_less_than_thousand(thousands) + " thousand ";
+                int_number %= 1000;
+            }
+
+            if (int_number > 0) {
+                result += convert_less_than_thousand(int_number);
+            }
+        }
+
+        // Handle decimal part
+        if (decimal_pos != std::string::npos) {
+            result += " point";
+            std::string decimal_part = number_str.substr(decimal_pos + 1);
+            for (char digit : decimal_part) {
+                result += " " + ones.at(digit - '0');
+            }
+        }
+
+        return result;
+    } catch (const std::exception& e) {
+        // Skip if fails
+        return " ";
+    }
+}
+
+static std::string replace_numbers_with_words(const std::string & input_text) {
+    std::regex number_pattern(R"(\d+(\.\d+)?)");
+    std::string result;
+    auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
+    auto end = std::sregex_iterator();
+
+    size_t last_pos = 0;
+    for (std::sregex_iterator i = it; i != end; ++i) {
+        const std::smatch& match = *i;
+        result.append(input_text, last_pos, match.position() - last_pos);
+        result.append(number_to_words(match.str()));
+        last_pos = match.position() + match.length();
+    }
+    result.append(input_text, last_pos);
+
+    return result;
+}
+
+// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
+static std::string process_text(const std::string & text) {
+
+    // For now I skipped text romanization as I am unsure how to handle
+    // uroman and MeCab implementations in C++
+    // maybe something like https://github.com/anyascii/anyascii/ could work.
+    // currently only English would be supported in this function
+
+    std::string processed_text = replace_numbers_with_words(text);
+
+    std::transform(processed_text.begin(), processed_text.end(),
+                  processed_text.begin(), ::tolower);
+
+    std::regex special_chars(R"([-_/,\.\\])");
+    processed_text = std::regex_replace(processed_text, special_chars, " ");
+
+    std::regex non_alpha(R"([^a-z\s])");
+    processed_text = std::regex_replace(processed_text, non_alpha, "");
+
+    std::regex multiple_spaces(R"(\s+)");
+    processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
+
+    processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
+
+    /*
+        Replace spaces with the separator token same as in line 365
+
+        for (auto & c : prompt_user) {
+        if (c == ' ') {
+            prompt_clean += "<|text_sep|>";
+    */
+    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
+
+    return processed_text;
+}
+
+static void prompt_add(llama_tokens & prompt, llama_token token) {
+    prompt.push_back(token);
+}
+
+static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
+    prompt.insert(prompt.end(), tokens.begin(), tokens.end());
+}
+
+static void prompt_add(llama_tokens & prompt, const llama_model * model, const std::string & txt, bool add_special, bool parse_special) {
+    auto tmp = common_tokenize(model, txt, add_special, parse_special);
+    prompt_add(prompt, tmp);
+}
+
+static void prompt_init(llama_tokens & prompt, const llama_model * model) {
+    prompt.clear();
+
+    prompt_add(prompt, model, "<|im_start|>\n", true, true);
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.prompt = "";
+
+    params.n_predict = 4096;
+    params.n_batch   = 8192;
+    params.n_ctx     = 8192;
+
+    params.sampling.top_k = 4;
+    params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
+        return 1;
+    }
+
+    const int n_parallel = params.n_parallel;
+    const int n_predict  = params.n_predict;
+
+    common_init();
+
+    // init LLM
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model_ttc = NULL; // text-to-codes
+    llama_model * model_cts = NULL; // codes-to-speech
+
+    llama_context * ctx_ttc = NULL;
+    llama_context * ctx_cts = NULL;
+
+    common_init_result llama_init_ttc = common_init_from_params(params);
+    model_ttc = llama_init_ttc.model;
+    ctx_ttc = llama_init_ttc.context;
+
+    // TODO: refactor in a common struct
+    params.model     = params.vocoder.model;
+    params.model_url = params.vocoder.model_url;
+    params.hf_repo   = params.vocoder.hf_repo;
+    params.hf_file   = params.vocoder.hf_file;
+
+    params.embedding = true;
+
+    common_init_result llama_init_cts = common_init_from_params(params);
+    model_cts = llama_init_cts.model;
+    ctx_cts = llama_init_cts.context;
+
+    std::vector<common_sampler *> smpl(n_parallel);
+    for (int i = 0; i < n_parallel; ++i) {
+        params.sampling.no_perf = (i != 0);
+        params.sampling.seed = params.sampling.seed + 1;
+
+        smpl[i] = common_sampler_init(model_ttc, params.sampling);
+    }
+
+    LOG_INF("sampler seed: %u\n",     common_sampler_get_seed(smpl[0]));
+    LOG_INF("sampler params: \n%s\n", params.sampling.print().c_str());
+    LOG_INF("sampler chain: %s\n",    common_sampler_print(smpl[0]).c_str());
+
+    LOG_INF("%s: loading done\n", __func__);
+
+    const auto t_main_start = ggml_time_us();
+
+    std::vector<llama_token> codes;
+
+    // process prompt and generate voice codes
+    {
+        LOG_INF("%s: constructing prompt ..\n", __func__);
+
+        std::vector<llama_token> prompt_inp;
+
+        prompt_init(prompt_inp, model_ttc);
+
+        prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
+
+        // convert the input text into the necessary format expected by OuteTTS
+        {
+            std::string prompt_clean = process_text(params.prompt);
+
+            LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
+
+            prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
+        }
+
+        prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
+
+        // disabled to save time on tokenizing each time
+        // TODO: load voices from the json files
+#if 0
+        const std::string voice_data = R"(<|audio_start|>
+the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
+overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
+package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
+from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|>
+just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|>
+two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|>
+people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|>
+is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|>
+pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|>
+remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|>
+sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|>
+i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|>
+have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|>
+some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|>
+critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|>
+about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|>
+some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|>
+of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|>
+the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|>
+gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|>
+aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|>
+but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|>
+its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|>
+still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|>
+really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|>
+enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|>
+and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|>
+it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|>
+looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
+lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
+
+        auto tmp = common_tokenize(model_ttc, voice_data, false, true);
+        printf("\n\n");
+        for (int i = 0; i < tmp.size(); ++i) {
+            printf("%d, ", tmp[i]);
+        }
+        printf("\n\n");
+#else
+        prompt_add(prompt_inp, llama_tokens {
+            151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
+            152460, 153375, 151670, 198, 74455, 155808, 151669, 151799,
+            151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470,
+            151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040,
+            153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691,
+            153368, 153437, 151670, 198, 1722, 155828, 151669, 152607,
+            152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198,
+            152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207,
+            152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267,
+            153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179,
+            153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904,
+            152311, 151670, 198, 1499, 155791, 151669, 152276, 152454,
+            153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226,
+            153043, 152325, 153267, 152622, 151670, 198, 4250, 155797,
+            151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
+            152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213,
+            152112, 153204, 151722, 152542, 151670, 198, 19789, 155796,
+            151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002,
+            152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224,
+            152244, 153387, 153404, 151670, 198, 16069, 155811, 151669,
+            152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733,
+            152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589,
+            152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504,
+            153376, 152272, 152433, 152325, 151941, 151670, 198, 285,
+            155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381,
+            152474, 152680, 152157, 153255, 152324, 151682, 151670, 198,
+            32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
+            152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488,
+            153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685,
+            152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669,
+            151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459,
+            153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609,
+            152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470,
+            152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018,
+            152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736,
+            153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457,
+            152393, 153112, 152595, 151670, 198, 19098, 155808, 151669,
+            152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239,
+            153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673,
+            152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482,
+            152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795,
+            152111, 152746, 152377, 153471, 152309, 151670, 198, 19016,
+            155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701,
+            152939, 152536, 152091, 151815, 152733, 151672, 151670, 198,
+            14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042,
+            153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670,
+            198, 36996, 8303, 155832, 151669, 152231, 152256, 152835,
+            152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600,
+            151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847,
+            153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458,
+            152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250,
+            152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
+            152228, 152733, 151670, 198, 9096, 155801, 151669, 151698,
+            153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106,
+            152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851,
+            152901, 152885, 152594, 153446, 153080, 151670, 198, 14689,
+            155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191,
+            151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384,
+            153364, 153188, 153246, 151670, 198, 1055, 155779, 151669,
+            151869, 152388, 152711, 153334, 151736, 151670, 198, 1782,
+            155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046,
+            151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605,
+            153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133,
+            152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581,
+            152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032,
+            152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409,
+            151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469,
+            152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230,
+            151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435,
+            152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540,
+            151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715,
+            153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
+            151670, 198, 8088, 155792, 151669, 152452, 153497, 153353,
+            152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285,
+            153411, 152495, 153141, 152320, 151670, 198, 1199, 155781,
+            151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271,
+            151670, 198, 43366, 155799, 151669, 152308, 151682, 152889,
+            152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180,
+            151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287,
+            151677, 151670, 198, 53660, 155808, 151669, 151727, 152092,
+            152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384,
+            151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691,
+            152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676,
+            153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
+            152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234,
+            153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678,
+            152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825,
+            152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957,
+            151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
+            151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174,
+            151792, 153409, 153327, 152990, 151670, 198, 275, 155781,
+            151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974,
+            151670, 198, 94273, 155799, 151669, 152953, 152938, 153427,
+            152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331,
+            152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326,
+            152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268,
+            153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110,
+            152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444,
+            152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
+            152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499,
+            152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
+            151670,});
+#endif
+
+        // print the prompt token-by-token
+
+        LOG("\n");
+
+        for (auto id : prompt_inp) {
+            LOG("%s", common_token_to_piece(ctx_ttc, id).c_str());
+        }
+
+        LOG_INF("%s: prompt size: %d\n", __func__, (int) prompt_inp.size());
+
+        LOG("\n");
+
+        // create a llama_batch
+        // we use this object to submit token data for decoding
+        llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
+
+        std::vector<llama_seq_id> seq_ids(n_parallel, 0);
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            seq_ids[i] = i;
+        }
+
+        // evaluate the initial prompt
+        for (size_t i = 0; i < prompt_inp.size(); ++i) {
+            common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
+        }
+        GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
+
+        // llama_decode will output logits only for the last token of the prompt
+        batch.logits[batch.n_tokens - 1] = true;
+
+        if (llama_decode(ctx_ttc, batch) != 0) {
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        if (n_parallel > 1) {
+            LOG_INF("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
+        }
+
+        llama_synchronize(ctx_ttc);
+
+        LOG_INF("%s: time for prompt: %.3f ms\n\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
+
+        const auto t_dec_start = ggml_time_us();
+
+        // main loop
+
+        // remember the batch index of the last token for each parallel sequence
+        // we need this to determine which logits to sample from
+        std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+
+        int n_past   = batch.n_tokens;
+        int n_decode = 0;
+
+        while (n_decode <= n_predict) {
+            // prepare the next batch
+            common_batch_clear(batch);
+
+            // sample the next token for each parallel sequence / stream
+            for (int32_t i = 0; i < n_parallel; ++i) {
+                if (i_batch[i] < 0) {
+                    // the stream has already finished
+                    continue;
+                }
+
+                const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+                common_sampler_accept(smpl[i], new_token_id, true);
+
+                codes.push_back(new_token_id);
+
+                const auto * cands = common_sampler_get_candidates(smpl[i]);
+
+                // is it an end of generation? -> mark the stream as finished
+                if (llama_token_is_eog(model_ttc, new_token_id) || n_decode == n_predict) {
+                    std::string reason;
+                    if (llama_token_is_eog(model_ttc, new_token_id)) {
+                        reason = "eos";
+                    } else {
+                        reason = "n_predict";
+                    }
+
+                    i_batch[i] = -1;
+
+                    LOG("\n");
+                    if (n_parallel > 1) {
+                        LOG_CNT("\n");
+                        LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n", __func__, i, n_past, reason.c_str());
+                    }
+
+                    continue;
+                }
+
+                {
+                    const float p = cands->data[cands->selected].p;
+
+                    const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) ((3*p)*float(k_colors.size()))));
+
+                    LOG_CNT("%s%d%s", k_colors[col].c_str(), i, "\033[0m");
+                    //LOG_CNT("%d", i);
+                }
+
+                i_batch[i] = batch.n_tokens;
+
+                // push this new token for next evaluation
+                common_batch_add(batch, new_token_id, n_past, { i }, true);
+            }
+
+            // all streams are finished
+            if (batch.n_tokens == 0) {
+                break;
+            }
+
+            n_decode += 1;
+            n_past += 1;
+
+            // evaluate the current batch with the transformer model
+            if (llama_decode(ctx_ttc, batch)) {
+                LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
+                return 1;
+            }
+        }
+
+        llama_batch_free(batch);
+
+        LOG("\n");
+        LOG_INF("%s: time for decoder:       %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
+    }
+
+    common_perf_print(ctx_ttc, smpl[0]);
+
+    //std::vector<llama_token> codes = {198, 88225, 155856, 151669, 152205,
+    //    153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695,
+    //    153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010,
+    //    153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286,
+    //    152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296,
+    //    153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690,
+    //    153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061,
+    //    153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670,
+    //    198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683,
+    //    152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908,
+    //    151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359,
+    //    153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424,
+    //    151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670,
+    //    198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729,
+    //    152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669,
+    //    153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670,
+    //    198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501,
+    //    152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242,
+    //    153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360,
+    //    153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055,
+    //    152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670,
+    //    198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441,
+    //    152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831,
+    //    153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133,
+    //    153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109,
+    //    152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055,
+    //    155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729,
+    //    151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337,
+    //    153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153,
+    //    153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365,
+    //    153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218,
+    //    152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464,
+    //    152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855,
+    //    152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418,
+    //    153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645};
+
+    {
+        const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
+
+        LOG("\n");
+        LOG_INF("codes: '%s'\n", inp_txt.c_str());
+        LOG_INF("%s: codes size: %d\n", __func__, (int) codes.size());
+    }
+
+    // remove all non-audio tokens (i.e. < 151672 || > 155772)
+    codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
+
+    {
+        const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
+        LOG_INF("codes audio: '%s'\n", inp_txt.c_str());
+        LOG_INF("%s: codes audio size: %d\n", __func__, (int) codes.size());
+    }
+
+    for (auto & token : codes) {
+        token -= 151672;
+    }
+
+    const auto t_voc_start = ggml_time_us();
+
+    const int n_codes = codes.size();
+
+    llama_batch batch = llama_batch_init(n_codes, 0, 1);
+
+    for (size_t i = 0; i < codes.size(); ++i) {
+        common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
+    }
+    GGML_ASSERT(batch.n_tokens == n_codes);
+
+    if (llama_decode(ctx_cts, batch) != 0) {
+        LOG_ERR("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    llama_synchronize(ctx_cts);
+
+    LOG_INF("%s: time for vocoder:      %.3f ms\n", __func__, (ggml_time_us() - t_voc_start) / 1000.0f);
+
+    const auto t_spec_start = ggml_time_us();
+
+#if 1
+    // spectral operations
+    const int n_embd = llama_n_embd(model_cts);
+    const float * embd = llama_get_embeddings(ctx_cts);
+
+    auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);
+
+#else
+    // read the spectrogram from a file for debugging purposes
+    std::vector<float> audio;
+    {
+        std::ifstream fin("out.bin", std::ios::binary);
+        if (!fin) {
+            LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin");
+            return 1;
+        }
+
+        std::vector<float> embd;
+
+        int n_codes;
+        int n_embd;
+
+        fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int));
+        fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int));
+
+        embd.resize(n_codes * n_embd);
+        fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float));
+        fin.close();
+
+        LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd);
+
+        audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
+    }
+#endif
+
+    const std::string fname = "output.wav";
+
+    const int n_sr = 24000; // sampling rate
+
+    // zero out first 0.25 seconds
+    for (int i = 0; i < 24000/4; ++i) {
+        audio[i] = 0.0f;
+    }
+
+    LOG_INF("%s: time for spectral ops: %.3f ms\n", __func__, (ggml_time_us() - t_spec_start) / 1000.0f);
+    LOG_INF("%s: total time:            %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
+
+    save_wav16(fname, audio, n_sr);
+
+    LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str());
+
+    llama_free(ctx_ttc);
+    llama_free_model(model_ttc);
+
+    llama_free(ctx_cts);
+    llama_free_model(model_cts);
+
+    llama_backend_free();
+
+    return 0;
+}
index b0c1ac9ce2b89629ab705aaed7adbae71cf7d898..c714fc8c837bba1c615f3b4f8f914cd6a1a57672 100644 (file)
@@ -1564,17 +1564,6 @@ extern "C" {
         int                   d1, // dilation dimension 1
         bool                  is_2D);
 
-    GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,  // convolution kernel
-            struct ggml_tensor  * b,  // data
-            int                  s0,  // stride dimension 0
-            int                  s1,  // stride dimension 1
-            int                  p0,  // padding dimension 0
-            int                  p1,  // padding dimension 1
-            int                  d0,  // dilation dimension 0
-            int                  d1); // dilation dimension 1
-
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // convolution kernel
@@ -1592,6 +1581,23 @@ extern "C" {
             int                   s,  // stride
             int                   d); // dilation
 
+    // depthwise
+    // TODO: this is very likely wrong for some cases! - needs more testing
+    GGML_API struct ggml_tensor * ggml_conv_1d_dw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
+
+    GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   d0); // dilation
+
     GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // convolution kernel
@@ -1611,7 +1617,6 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
-
     // kernel size is a->ne[0] x a->ne[1]
     // stride is equal to kernel size
     // padding is zero
@@ -1638,6 +1643,18 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // depthwise
+    GGML_API struct ggml_tensor * ggml_conv_2d_dw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                  s0,  // stride dimension 0
+            int                  s1,  // stride dimension 1
+            int                  p0,  // padding dimension 0
+            int                  p1,  // padding dimension 1
+            int                  d0,  // dilation dimension 0
+            int                  d1); // dilation dimension 1
+
     GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 0efd2b2ebf780993fc819607b78bd0ffd90764df..2bbe5f48257b2f5c3fe37c4a14cb263eac028c09 100644 (file)
@@ -3760,13 +3760,84 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d
-
 static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-GGML_API struct ggml_tensor * ggml_conv_1d(
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OH, OW, IC*KH*KW]
+struct ggml_tensor * ggml_im2col(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1,
+        bool                  is_2D,
+        enum ggml_type        dst_type) {
+    if (is_2D) {
+        GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+        GGML_ASSERT(b->ne[1] == a->ne[1]);
+        GGML_ASSERT(b->ne[3] == 1);
+    }
+
+    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+
+    GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
+    GGML_ASSERT((OW > 0)           && "b too small compared to a");
+
+    const int64_t ne[4] = {
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
+        OW,
+        is_2D ? OH : b->ne[2],
+        is_2D ?      b->ne[3] : 1,
+    };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_IM2COL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_im2col_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int64_t             * ne,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1,
+        bool                  is_2D) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_IM2COL_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_conv_1d
+
+struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -3796,137 +3867,75 @@ struct ggml_tensor* ggml_conv_1d_ph(
     return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
 }
 
-// ggml_conv_transpose_1d
-
-static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
-    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
-}
+// ggml_conv_1d_dw
 
-GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+struct ggml_tensor * ggml_conv_1d_dw(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   s0,
         int                   p0,
         int                   d0) {
-    GGML_ASSERT(ggml_is_matrix(b));
-    GGML_ASSERT(a->ne[2] == b->ne[1]);
-    GGML_ASSERT(a->ne[3] == 1);
+    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
 
-    GGML_ASSERT(p0 == 0);
-    GGML_ASSERT(d0 == 1);
+    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
 
-    const int64_t ne[4] = {
-        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
-        a->ne[1], b->ne[2], 1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
 
-    int32_t params[] = { s0, p0, d0 };
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
-    result->src[0] = a;
-    result->src[1] = b;
+    result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
 
     return result;
 }
 
-// ggml_conv_depthwise
+// ggml_conv_1d_dw_ph
 
-struct ggml_tensor * ggml_conv_depthwise_2d(
+struct ggml_tensor * ggml_conv_1d_dw_ph(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   s0,
-        int                   s1,
-        int                   p0,
-        int                   p1,
-        int                   d0,
-        int                   d1) {
-    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
-    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
-                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
-                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
-    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+        int                   d0) {
+    return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
+}
 
-    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1);                       // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
-    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
-    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+// ggml_conv_transpose_1d
 
-    return result;
+static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
 }
-// ggml_conv_2d
 
-// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-// a: [OC,IC, KH, KW]
-// b: [N, IC, IH, IW]
-// result: [N, OH, OW, IC*KH*KW]
-struct ggml_tensor * ggml_im2col(
+GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   s0,
-        int                   s1,
         int                   p0,
-        int                   p1,
-        int                   d0,
-        int                   d1,
-        bool                  is_2D,
-        enum ggml_type        dst_type) {
-    if(is_2D) {
-        GGML_ASSERT(a->ne[2] == b->ne[2]);
-    } else {
-        GGML_ASSERT(a->ne[1] == b->ne[1]);
-        GGML_ASSERT(b->ne[3] == 1);
-    }
-
-    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
-    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+        int                   d0) {
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
 
-    GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
-    GGML_ASSERT((OW > 0)           && "b too small compared to a");
+    GGML_ASSERT(p0 == 0);
+    GGML_ASSERT(d0 == 1);
 
     const int64_t ne[4] = {
-        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
-        OW,
-        is_2D ? OH : b->ne[2],
-        is_2D ?      b->ne[3] : 1,
+        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
+        a->ne[1], b->ne[2], 1,
     };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
-    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+    int32_t params[] = { s0, p0, d0 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op     = GGML_OP_IM2COL;
+    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
     result->src[0] = a;
     result->src[1] = b;
 
     return result;
 }
 
-struct ggml_tensor * ggml_im2col_back(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int64_t             * ne,
-        int                   s0,
-        int                   s1,
-        int                   p0,
-        int                   p1,
-        int                   d0,
-        int                   d1,
-        bool                  is_2D) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_IM2COL_BACK;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
+// ggml_conv_2d
 
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
@@ -3973,6 +3982,31 @@ struct ggml_tensor * ggml_conv_2d_s1_ph(
     return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
 }
 
+// ggml_conv_2d_dw
+
+struct ggml_tensor * ggml_conv_2d_dw(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
+    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
+    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
+                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
+                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+
+    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1);                       // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
index c2c7cad14e500ae172a867e2175cdbcf4eeba883..a40df974d1fcaad2c83105a64f48148249b3937b 100644 (file)
@@ -90,6 +90,7 @@ class Keys:
         VOCAB_SIZE                        = "{arch}.vocab_size"
         CONTEXT_LENGTH                    = "{arch}.context_length"
         EMBEDDING_LENGTH                  = "{arch}.embedding_length"
+        FEATURES_LENGTH                   = "{arch}.features_length"
         BLOCK_COUNT                       = "{arch}.block_count"
         LEADING_DENSE_BLOCK_COUNT         = "{arch}.leading_dense_block_count"
         FEED_FORWARD_LENGTH               = "{arch}.feed_forward_length"
@@ -122,6 +123,8 @@ class Keys:
         VALUE_LENGTH      = "{arch}.attention.value_length"
         LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+        GROUPNORM_EPS     = "{arch}.attention.group_norm_epsilon"
+        GROUPNORM_GROUPS  = "{arch}.attention.group_norm_groups"
         CAUSAL            = "{arch}.attention.causal"
         Q_LORA_RANK       = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK      = "{arch}.attention.kv_lora_rank"
@@ -155,6 +158,14 @@ class Keys:
     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
 
+    class PosNet:
+        EMBEDDING_LENGTH = "{arch}.posnet.embedding_length"
+        BLOCK_COUNT      = "{arch}.posnet.block_count"
+
+    class ConvNext:
+        EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
+        BLOCK_COUNT      = "{arch}.convnext.block_count"
+
     class Tokenizer:
         MODEL                = "tokenizer.ggml.model"
         PRE                  = "tokenizer.ggml.pre"
@@ -209,58 +220,59 @@ class GGUFType:
 
 
 class MODEL_ARCH(IntEnum):
-    LLAMA        = auto()
-    FALCON       = auto()
-    BAICHUAN     = auto()
-    GROK         = auto()
-    GPT2         = auto()
-    GPTJ         = auto()
-    GPTNEOX      = auto()
-    MPT          = auto()
-    STARCODER    = auto()
-    REFACT       = auto()
-    BERT         = auto()
-    NOMIC_BERT   = auto()
-    JINA_BERT_V2 = auto()
-    BLOOM        = auto()
-    STABLELM     = auto()
-    QWEN         = auto()
-    QWEN2        = auto()
-    QWEN2MOE     = auto()
-    QWEN2VL      = auto()
-    PHI2         = auto()
-    PHI3         = auto()
-    PLAMO        = auto()
-    CODESHELL    = auto()
-    ORION        = auto()
-    INTERNLM2    = auto()
-    MINICPM      = auto()
-    MINICPM3     = auto()
-    GEMMA        = auto()
-    GEMMA2       = auto()
-    STARCODER2   = auto()
-    RWKV6        = auto()
-    MAMBA        = auto()
-    XVERSE       = auto()
-    COMMAND_R    = auto()
-    DBRX         = auto()
-    OLMO         = auto()
-    OLMO2        = auto()
-    OLMOE        = auto()
-    OPENELM      = auto()
-    ARCTIC       = auto()
-    DEEPSEEK     = auto()
-    DEEPSEEK2    = auto()
-    CHATGLM      = auto()
-    BITNET       = auto()
-    T5           = auto()
-    T5ENCODER    = auto()
-    JAIS         = auto()
-    NEMOTRON     = auto()
-    EXAONE       = auto()
-    GRANITE      = auto()
-    GRANITE_MOE  = auto()
-    CHAMELEON    = auto()
+    LLAMA            = auto()
+    FALCON           = auto()
+    BAICHUAN         = auto()
+    GROK             = auto()
+    GPT2             = auto()
+    GPTJ             = auto()
+    GPTNEOX          = auto()
+    MPT              = auto()
+    STARCODER        = auto()
+    REFACT           = auto()
+    BERT             = auto()
+    NOMIC_BERT       = auto()
+    JINA_BERT_V2     = auto()
+    BLOOM            = auto()
+    STABLELM         = auto()
+    QWEN             = auto()
+    QWEN2            = auto()
+    QWEN2MOE         = auto()
+    QWEN2VL          = auto()
+    PHI2             = auto()
+    PHI3             = auto()
+    PLAMO            = auto()
+    CODESHELL        = auto()
+    ORION            = auto()
+    INTERNLM2        = auto()
+    MINICPM          = auto()
+    MINICPM3         = auto()
+    GEMMA            = auto()
+    GEMMA2           = auto()
+    STARCODER2       = auto()
+    RWKV6            = auto()
+    MAMBA            = auto()
+    XVERSE           = auto()
+    COMMAND_R        = auto()
+    DBRX             = auto()
+    OLMO             = auto()
+    OLMO2            = auto()
+    OLMOE            = auto()
+    OPENELM          = auto()
+    ARCTIC           = auto()
+    DEEPSEEK         = auto()
+    DEEPSEEK2        = auto()
+    CHATGLM          = auto()
+    BITNET           = auto()
+    T5               = auto()
+    T5ENCODER        = auto()
+    JAIS             = auto()
+    NEMOTRON         = auto()
+    EXAONE           = auto()
+    GRANITE          = auto()
+    GRANITE_MOE      = auto()
+    CHAMELEON        = auto()
+    WAVTOKENIZER_DEC = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -370,61 +382,78 @@ class MODEL_TENSOR(IntEnum):
     ENC_OUTPUT_NORM      = auto()
     CLS                  = auto() # classifier
     CLS_OUT              = auto() # classifier output projection
+    CONV1D               = auto()
+    CONVNEXT_DW          = auto()
+    CONVNEXT_NORM        = auto()
+    CONVNEXT_PW1         = auto()
+    CONVNEXT_PW2         = auto()
+    CONVNEXT_GAMMA       = auto()
+    POSNET_CONV1         = auto()
+    POSNET_CONV2         = auto()
+    POSNET_NORM          = auto()
+    POSNET_NORM1         = auto()
+    POSNET_NORM2         = auto()
+    POSNET_ATTN_NORM     = auto()
+    POSNET_ATTN_Q        = auto()
+    POSNET_ATTN_K        = auto()
+    POSNET_ATTN_V        = auto()
+    POSNET_ATTN_OUT      = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
-    MODEL_ARCH.LLAMA:          "llama",
-    MODEL_ARCH.FALCON:         "falcon",
-    MODEL_ARCH.BAICHUAN:       "baichuan",
-    MODEL_ARCH.GROK:           "grok",
-    MODEL_ARCH.GPT2:           "gpt2",
-    MODEL_ARCH.GPTJ:           "gptj",
-    MODEL_ARCH.GPTNEOX:        "gptneox",
-    MODEL_ARCH.MPT:            "mpt",
-    MODEL_ARCH.STARCODER:      "starcoder",
-    MODEL_ARCH.REFACT:         "refact",
-    MODEL_ARCH.BERT:           "bert",
-    MODEL_ARCH.NOMIC_BERT:     "nomic-bert",
-    MODEL_ARCH.JINA_BERT_V2:   "jina-bert-v2",
-    MODEL_ARCH.BLOOM:          "bloom",
-    MODEL_ARCH.STABLELM:       "stablelm",
-    MODEL_ARCH.QWEN:           "qwen",
-    MODEL_ARCH.QWEN2:          "qwen2",
-    MODEL_ARCH.QWEN2MOE:       "qwen2moe",
-    MODEL_ARCH.QWEN2VL:        "qwen2vl",
-    MODEL_ARCH.PHI2:           "phi2",
-    MODEL_ARCH.PHI3:           "phi3",
-    MODEL_ARCH.PLAMO:          "plamo",
-    MODEL_ARCH.CODESHELL:      "codeshell",
-    MODEL_ARCH.ORION:          "orion",
-    MODEL_ARCH.INTERNLM2:      "internlm2",
-    MODEL_ARCH.MINICPM:        "minicpm",
-    MODEL_ARCH.MINICPM3:       "minicpm3",
-    MODEL_ARCH.GEMMA:          "gemma",
-    MODEL_ARCH.GEMMA2:         "gemma2",
-    MODEL_ARCH.STARCODER2:     "starcoder2",
-    MODEL_ARCH.RWKV6:          "rwkv6",
-    MODEL_ARCH.MAMBA:          "mamba",
-    MODEL_ARCH.XVERSE:         "xverse",
-    MODEL_ARCH.COMMAND_R:      "command-r",
-    MODEL_ARCH.DBRX:           "dbrx",
-    MODEL_ARCH.OLMO:           "olmo",
-    MODEL_ARCH.OLMO2:          "olmo2",
-    MODEL_ARCH.OLMOE:          "olmoe",
-    MODEL_ARCH.OPENELM:        "openelm",
-    MODEL_ARCH.ARCTIC:         "arctic",
-    MODEL_ARCH.DEEPSEEK:       "deepseek",
-    MODEL_ARCH.DEEPSEEK2:      "deepseek2",
-    MODEL_ARCH.CHATGLM:        "chatglm",
-    MODEL_ARCH.BITNET:         "bitnet",
-    MODEL_ARCH.T5:             "t5",
-    MODEL_ARCH.T5ENCODER:      "t5encoder",
-    MODEL_ARCH.JAIS:           "jais",
-    MODEL_ARCH.NEMOTRON:       "nemotron",
-    MODEL_ARCH.EXAONE:         "exaone",
-    MODEL_ARCH.GRANITE:        "granite",
-    MODEL_ARCH.GRANITE_MOE:    "granitemoe",
-    MODEL_ARCH.CHAMELEON:      "chameleon",
+    MODEL_ARCH.LLAMA:            "llama",
+    MODEL_ARCH.FALCON:           "falcon",
+    MODEL_ARCH.BAICHUAN:         "baichuan",
+    MODEL_ARCH.GROK:             "grok",
+    MODEL_ARCH.GPT2:             "gpt2",
+    MODEL_ARCH.GPTJ:             "gptj",
+    MODEL_ARCH.GPTNEOX:          "gptneox",
+    MODEL_ARCH.MPT:              "mpt",
+    MODEL_ARCH.STARCODER:        "starcoder",
+    MODEL_ARCH.REFACT:           "refact",
+    MODEL_ARCH.BERT:             "bert",
+    MODEL_ARCH.NOMIC_BERT:       "nomic-bert",
+    MODEL_ARCH.JINA_BERT_V2:     "jina-bert-v2",
+    MODEL_ARCH.BLOOM:            "bloom",
+    MODEL_ARCH.STABLELM:         "stablelm",
+    MODEL_ARCH.QWEN:             "qwen",
+    MODEL_ARCH.QWEN2:            "qwen2",
+    MODEL_ARCH.QWEN2MOE:         "qwen2moe",
+    MODEL_ARCH.QWEN2VL:          "qwen2vl",
+    MODEL_ARCH.PHI2:             "phi2",
+    MODEL_ARCH.PHI3:             "phi3",
+    MODEL_ARCH.PLAMO:            "plamo",
+    MODEL_ARCH.CODESHELL:        "codeshell",
+    MODEL_ARCH.ORION:            "orion",
+    MODEL_ARCH.INTERNLM2:        "internlm2",
+    MODEL_ARCH.MINICPM:          "minicpm",
+    MODEL_ARCH.MINICPM3:         "minicpm3",
+    MODEL_ARCH.GEMMA:            "gemma",
+    MODEL_ARCH.GEMMA2:           "gemma2",
+    MODEL_ARCH.STARCODER2:       "starcoder2",
+    MODEL_ARCH.RWKV6:            "rwkv6",
+    MODEL_ARCH.MAMBA:            "mamba",
+    MODEL_ARCH.XVERSE:           "xverse",
+    MODEL_ARCH.COMMAND_R:        "command-r",
+    MODEL_ARCH.DBRX:             "dbrx",
+    MODEL_ARCH.OLMO:             "olmo",
+    MODEL_ARCH.OLMO2:            "olmo2",
+    MODEL_ARCH.OLMOE:            "olmoe",
+    MODEL_ARCH.OPENELM:          "openelm",
+    MODEL_ARCH.ARCTIC:           "arctic",
+    MODEL_ARCH.DEEPSEEK:         "deepseek",
+    MODEL_ARCH.DEEPSEEK2:        "deepseek2",
+    MODEL_ARCH.CHATGLM:          "chatglm",
+    MODEL_ARCH.BITNET:           "bitnet",
+    MODEL_ARCH.T5:               "t5",
+    MODEL_ARCH.T5ENCODER:        "t5encoder",
+    MODEL_ARCH.JAIS:             "jais",
+    MODEL_ARCH.NEMOTRON:         "nemotron",
+    MODEL_ARCH.EXAONE:           "exaone",
+    MODEL_ARCH.GRANITE:          "granite",
+    MODEL_ARCH.GRANITE_MOE:      "granitemoe",
+    MODEL_ARCH.CHAMELEON:        "chameleon",
+    MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -534,6 +563,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ENC_OUTPUT_NORM:           "enc.output_norm",
     MODEL_TENSOR.CLS:                       "cls",
     MODEL_TENSOR.CLS_OUT:                   "cls.output",
+    MODEL_TENSOR.CONV1D:                    "conv1d",
+    MODEL_TENSOR.CONVNEXT_DW:               "convnext.{bid}.dw",
+    MODEL_TENSOR.CONVNEXT_NORM:             "convnext.{bid}.norm",
+    MODEL_TENSOR.CONVNEXT_PW1:              "convnext.{bid}.pw1",
+    MODEL_TENSOR.CONVNEXT_PW2:              "convnext.{bid}.pw2",
+    MODEL_TENSOR.CONVNEXT_GAMMA:            "convnext.{bid}.gamma",
+    MODEL_TENSOR.POSNET_CONV1:              "posnet.{bid}.conv1",
+    MODEL_TENSOR.POSNET_CONV2:              "posnet.{bid}.conv2",
+    MODEL_TENSOR.POSNET_NORM:               "posnet.{bid}.norm",
+    MODEL_TENSOR.POSNET_NORM1:              "posnet.{bid}.norm1",
+    MODEL_TENSOR.POSNET_NORM2:              "posnet.{bid}.norm2",
+    MODEL_TENSOR.POSNET_ATTN_NORM:          "posnet.{bid}.attn_norm",
+    MODEL_TENSOR.POSNET_ATTN_Q:             "posnet.{bid}.attn_q",
+    MODEL_TENSOR.POSNET_ATTN_K:             "posnet.{bid}.attn_k",
+    MODEL_TENSOR.POSNET_ATTN_V:             "posnet.{bid}.attn_v",
+    MODEL_TENSOR.POSNET_ATTN_OUT:           "posnet.{bid}.attn_output",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -1372,6 +1417,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.WAVTOKENIZER_DEC: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.CONV1D,
+        MODEL_TENSOR.CONVNEXT_DW,
+        MODEL_TENSOR.CONVNEXT_NORM,
+        MODEL_TENSOR.CONVNEXT_PW1,
+        MODEL_TENSOR.CONVNEXT_PW2,
+        MODEL_TENSOR.CONVNEXT_GAMMA,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.POSNET_CONV1,
+        MODEL_TENSOR.POSNET_CONV2,
+        MODEL_TENSOR.POSNET_NORM,
+        MODEL_TENSOR.POSNET_NORM1,
+        MODEL_TENSOR.POSNET_NORM2,
+        MODEL_TENSOR.POSNET_ATTN_NORM,
+        MODEL_TENSOR.POSNET_ATTN_Q,
+        MODEL_TENSOR.POSNET_ATTN_K,
+        MODEL_TENSOR.POSNET_ATTN_V,
+        MODEL_TENSOR.POSNET_ATTN_OUT,
+    ],
     # TODO
 }
 
index 65a64e10dd33fe18937cc58feb526546144bae45..3023b539ae82bb8cf17b54a7bf076b0d0c312155 100644 (file)
@@ -631,6 +631,21 @@ class GGUFWriter:
     def add_embedding_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
 
+    def add_features_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
+
+    def add_posnet_embedding_length(self, length: int) -> None:
+        self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+    def add_posnet_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)
+
+    def add_convnext_embedding_length(self, length: int) -> None:
+        self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+    def add_convnext_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)
+
     def add_block_count(self, length: int) -> None:
         self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
 
@@ -727,6 +742,12 @@ class GGUFWriter:
     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
 
+    def add_group_norm_eps(self, value: float) -> None:
+        self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
+
+    def add_group_norm_groups(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
+
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 
index 573d0282ea599abb63523a9fbd24b9fe9dba5669..82cdb121a1f2657ebf0128b40a2ee1c943d61c8f 100644 (file)
@@ -42,6 +42,7 @@ class TensorNameMap:
             "emb_ln",                     # nomic-bert
             "transformer.norm",           # openelm
             "rwkv.blocks.0.pre_ln",       # rwkv
+            "backbone.norm",              # wavtokenizer
         ),
 
         # Position embeddings
@@ -60,6 +61,7 @@ class TensorNameMap:
             "lm_head.linear",            # phi2
             "output_layer",              # chatglm
             "head",                      # rwkv
+            "head.out",                  # wavtokenizer
         ),
 
         # Output norm
@@ -80,6 +82,7 @@ class TensorNameMap:
             "transformer.norm",                        # openelm
             "model.norm",                              # nemotron
             "rwkv.ln_out",                             # rwkv
+            "backbone.final_layer_norm",               # wavtokenizer
         ),
 
         # Rope frequencies
@@ -90,6 +93,10 @@ class TensorNameMap:
 
         MODEL_TENSOR.ROPE_FACTORS_LONG: (),
         MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
+
+        MODEL_TENSOR.CONV1D: (
+            "backbone.embed", # roberta
+        ),
     }
 
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
@@ -681,6 +688,8 @@ class TensorNameMap:
             "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
         ),
 
+        ############################################################################
+        # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
         MODEL_TENSOR.ENC_OUTPUT_NORM: (
             "encoder.final_layer_norm", # t5
         ),
@@ -693,6 +702,67 @@ class TensorNameMap:
         MODEL_TENSOR.CLS_OUT: (
             "classifier.out_proj", # roberta
         ),
+        #############################################################################
+
+        MODEL_TENSOR.CONVNEXT_DW: (
+            "backbone.convnext.{bid}.dwconv", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.CONVNEXT_NORM: (
+            "backbone.convnext.{bid}.norm", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.CONVNEXT_PW1: (
+            "backbone.convnext.{bid}.pwconv1", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.CONVNEXT_PW2: (
+            "backbone.convnext.{bid}.pwconv2", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.CONVNEXT_GAMMA: (
+            "backbone.convnext.{bid}.gamma", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_CONV1: (
+            "backbone.posnet.{bid}.conv1", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_CONV2: (
+            "backbone.posnet.{bid}.conv2", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_NORM: (
+            "backbone.posnet.{bid}.norm", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_NORM1: (
+            "backbone.posnet.{bid}.norm1", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_NORM2: (
+            "backbone.posnet.{bid}.norm2", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_ATTN_NORM: (
+            "backbone.posnet.{bid}.norm", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_ATTN_Q: (
+            "backbone.posnet.{bid}.q", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_ATTN_K: (
+            "backbone.posnet.{bid}.k", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_ATTN_V: (
+            "backbone.posnet.{bid}.v", # wavtokenizer
+        ),
+
+        MODEL_TENSOR.POSNET_ATTN_OUT: (
+            "backbone.posnet.{bid}.proj_out", # wavtokenizer
+        ),
     }
 
     # architecture-specific block mappings
index 762067814224e730d8eaf600bed816e77a92003d..f04d5acce279325e5f45b27339c3166803d94e00 100755 (executable)
@@ -136,7 +136,7 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
         logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
 
         sum_diff_bits = np.sum(diff_bits)
-        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)")
+        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
         return False
 
 
index efbb27d21523acbfa8a3253998f05aa062becca1..a4abf395bcd93de266245bd053a4b600d42b7f96 100644 (file)
@@ -482,9 +482,6 @@ extern "C" {
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
-    // Get a llama model tensor
-    LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
-
     // Returns true if the model contains an encoder that requires llama_encode() call
     LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
 
index e38e598532345ef9cefe69666477c99baf776b08..7f2725f94be1376228171fbb42276f662f033afc 100644 (file)
@@ -1867,6 +1867,10 @@ int32_t llama_detokenize_impl(
                          int32_t   text_len_max,
                             bool   remove_special,
                             bool   unparse_special) {
+    if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
+        return 0;
+    }
+
     GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
 
     int32_t avail = text_len_max;
index b7b04a41d99e6467f61e5e945d3e1aee784ddc4e..94160d534435f8759bc0b7831c63c19e6d8cd063 100644 (file)
@@ -197,63 +197,65 @@ enum llm_arch {
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
+    LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_UNKNOWN,
 };
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
-    { LLM_ARCH_LLAMA,           "llama"        },
-    { LLM_ARCH_FALCON,          "falcon"       },
-    { LLM_ARCH_GROK,            "grok"         },
-    { LLM_ARCH_GPT2,            "gpt2"         },
-    { LLM_ARCH_GPTJ,            "gptj"         },
-    { LLM_ARCH_GPTNEOX,         "gptneox"      },
-    { LLM_ARCH_MPT,             "mpt"          },
-    { LLM_ARCH_BAICHUAN,        "baichuan"     },
-    { LLM_ARCH_STARCODER,       "starcoder"    },
-    { LLM_ARCH_REFACT,          "refact"       },
-    { LLM_ARCH_BERT,            "bert"         },
-    { LLM_ARCH_NOMIC_BERT,      "nomic-bert"   },
-    { LLM_ARCH_JINA_BERT_V2,    "jina-bert-v2" },
-    { LLM_ARCH_BLOOM,           "bloom"        },
-    { LLM_ARCH_STABLELM,        "stablelm"     },
-    { LLM_ARCH_QWEN,            "qwen"         },
-    { LLM_ARCH_QWEN2,           "qwen2"        },
-    { LLM_ARCH_QWEN2MOE,        "qwen2moe"     },
-    { LLM_ARCH_QWEN2VL,         "qwen2vl"      },
-    { LLM_ARCH_PHI2,            "phi2"         },
-    { LLM_ARCH_PHI3,            "phi3"         },
-    { LLM_ARCH_PLAMO,           "plamo"        },
-    { LLM_ARCH_CODESHELL,       "codeshell"    },
-    { LLM_ARCH_ORION,           "orion"        },
-    { LLM_ARCH_INTERNLM2,       "internlm2"    },
-    { LLM_ARCH_MINICPM,         "minicpm"      },
-    { LLM_ARCH_MINICPM3,        "minicpm3"     },
-    { LLM_ARCH_GEMMA,           "gemma"        },
-    { LLM_ARCH_GEMMA2,          "gemma2"       },
-    { LLM_ARCH_STARCODER2,      "starcoder2"   },
-    { LLM_ARCH_MAMBA,           "mamba"        },
-    { LLM_ARCH_XVERSE,          "xverse"       },
-    { LLM_ARCH_COMMAND_R,       "command-r"    },
-    { LLM_ARCH_DBRX,            "dbrx"         },
-    { LLM_ARCH_OLMO,            "olmo"         },
-    { LLM_ARCH_OLMO2,           "olmo2"        },
-    { LLM_ARCH_OLMOE,           "olmoe"        },
-    { LLM_ARCH_OPENELM,         "openelm"      },
-    { LLM_ARCH_ARCTIC,          "arctic"       },
-    { LLM_ARCH_DEEPSEEK,        "deepseek"     },
-    { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
-    { LLM_ARCH_CHATGLM,         "chatglm"      },
-    { LLM_ARCH_BITNET,          "bitnet"       },
-    { LLM_ARCH_T5,              "t5"           },
-    { LLM_ARCH_T5ENCODER,       "t5encoder"    },
-    { LLM_ARCH_JAIS,            "jais"         },
-    { LLM_ARCH_NEMOTRON,        "nemotron"     },
-    { LLM_ARCH_EXAONE,          "exaone"       },
-    { LLM_ARCH_RWKV6,           "rwkv6"        },
-    { LLM_ARCH_GRANITE,         "granite"      },
-    { LLM_ARCH_GRANITE_MOE,     "granitemoe"   },
-    { LLM_ARCH_CHAMELEON,       "chameleon"    },
-    { LLM_ARCH_UNKNOWN,         "(unknown)"    },
+    { LLM_ARCH_LLAMA,            "llama"            },
+    { LLM_ARCH_FALCON,           "falcon"           },
+    { LLM_ARCH_GROK,             "grok"             },
+    { LLM_ARCH_GPT2,             "gpt2"             },
+    { LLM_ARCH_GPTJ,             "gptj"             },
+    { LLM_ARCH_GPTNEOX,          "gptneox"          },
+    { LLM_ARCH_MPT,              "mpt"              },
+    { LLM_ARCH_BAICHUAN,         "baichuan"         },
+    { LLM_ARCH_STARCODER,        "starcoder"        },
+    { LLM_ARCH_REFACT,           "refact"           },
+    { LLM_ARCH_BERT,             "bert"             },
+    { LLM_ARCH_NOMIC_BERT,       "nomic-bert"       },
+    { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
+    { LLM_ARCH_BLOOM,            "bloom"            },
+    { LLM_ARCH_STABLELM,         "stablelm"         },
+    { LLM_ARCH_QWEN,             "qwen"             },
+    { LLM_ARCH_QWEN2,            "qwen2"            },
+    { LLM_ARCH_QWEN2MOE,         "qwen2moe"         },
+    { LLM_ARCH_QWEN2VL,          "qwen2vl"          },
+    { LLM_ARCH_PHI2,             "phi2"             },
+    { LLM_ARCH_PHI3,             "phi3"             },
+    { LLM_ARCH_PLAMO,            "plamo"            },
+    { LLM_ARCH_CODESHELL,        "codeshell"        },
+    { LLM_ARCH_ORION,            "orion"            },
+    { LLM_ARCH_INTERNLM2,        "internlm2"        },
+    { LLM_ARCH_MINICPM,          "minicpm"          },
+    { LLM_ARCH_MINICPM3,         "minicpm3"         },
+    { LLM_ARCH_GEMMA,            "gemma"            },
+    { LLM_ARCH_GEMMA2,           "gemma2"           },
+    { LLM_ARCH_STARCODER2,       "starcoder2"       },
+    { LLM_ARCH_MAMBA,            "mamba"            },
+    { LLM_ARCH_XVERSE,           "xverse"           },
+    { LLM_ARCH_COMMAND_R,        "command-r"        },
+    { LLM_ARCH_DBRX,             "dbrx"             },
+    { LLM_ARCH_OLMO,             "olmo"             },
+    { LLM_ARCH_OLMO2,            "olmo2"            },
+    { LLM_ARCH_OLMOE,            "olmoe"            },
+    { LLM_ARCH_OPENELM,          "openelm"          },
+    { LLM_ARCH_ARCTIC,           "arctic"           },
+    { LLM_ARCH_DEEPSEEK,         "deepseek"         },
+    { LLM_ARCH_DEEPSEEK2,        "deepseek2"        },
+    { LLM_ARCH_CHATGLM,          "chatglm"          },
+    { LLM_ARCH_BITNET,           "bitnet"           },
+    { LLM_ARCH_T5,               "t5"               },
+    { LLM_ARCH_T5ENCODER,        "t5encoder"        },
+    { LLM_ARCH_JAIS,             "jais"             },
+    { LLM_ARCH_NEMOTRON,         "nemotron"         },
+    { LLM_ARCH_EXAONE,           "exaone"           },
+    { LLM_ARCH_RWKV6,            "rwkv6"            },
+    { LLM_ARCH_GRANITE,          "granite"          },
+    { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
+    { LLM_ARCH_CHAMELEON,        "chameleon"        },
+    { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
 enum llm_kv {
@@ -273,6 +275,7 @@ enum llm_kv {
     LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
+    LLM_KV_FEATURES_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
     LLM_KV_FEED_FORWARD_LENGTH,
@@ -304,6 +307,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_VALUE_LENGTH,
     LLM_KV_ATTENTION_LAYERNORM_EPS,
     LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
+    LLM_KV_ATTENTION_GROUPNORM_EPS,
+    LLM_KV_ATTENTION_GROUPNORM_GROUPS,
     LLM_KV_ATTENTION_CAUSAL,
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
@@ -367,6 +372,12 @@ enum llm_kv {
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
 
+    LLM_KV_POSNET_EMBEDDING_LENGTH,
+    LLM_KV_POSNET_BLOCK_COUNT,
+
+    LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
+    LLM_KV_CONVNEXT_BLOCK_COUNT,
+
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
@@ -390,6 +401,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                        },
     { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"                    },
     { LLM_KV_EMBEDDING_LENGTH,                  "%s.embedding_length"                  },
+    { LLM_KV_FEATURES_LENGTH,                   "%s.features_length"                   },
     { LLM_KV_BLOCK_COUNT,                       "%s.block_count"                       },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"         },
     { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
@@ -421,6 +433,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
     { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
+    { LLM_KV_ATTENTION_GROUPNORM_EPS,          "%s.attention.group_norm_epsilon"     },
+    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,       "%s.attention.group_norm_groups"      },
     { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
     { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
@@ -451,6 +465,12 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 
     { LLM_KV_WKV_HEAD_SIZE,                    "%s.wkv.head_size" },
 
+    { LLM_KV_POSNET_EMBEDDING_LENGTH,          "%s.posnet.embedding_length" },
+    { LLM_KV_POSNET_BLOCK_COUNT,               "%s.posnet.block_count"      },
+
+    { LLM_KV_CONVNEXT_EMBEDDING_LENGTH,        "%s.convnext.embedding_length" },
+    { LLM_KV_CONVNEXT_BLOCK_COUNT,             "%s.convnext.block_count"      },
+
     { LLM_KV_TOKENIZER_MODEL,                  "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                    "tokenizer.ggml.pre"                      },
     { LLM_KV_TOKENIZER_LIST,                   "tokenizer.ggml.tokens"                   },
@@ -609,6 +629,22 @@ enum llm_tensor {
     LLM_TENSOR_ENC_OUTPUT_NORM,
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
+    LLM_TENSOR_CONV1D,
+    LLM_TENSOR_CONVNEXT_DW,
+    LLM_TENSOR_CONVNEXT_NORM,
+    LLM_TENSOR_CONVNEXT_PW1,
+    LLM_TENSOR_CONVNEXT_PW2,
+    LLM_TENSOR_CONVNEXT_GAMMA,
+    LLM_TENSOR_POS_NET_CONV1,
+    LLM_TENSOR_POS_NET_CONV2,
+    LLM_TENSOR_POS_NET_NORM,
+    LLM_TENSOR_POS_NET_NORM1,
+    LLM_TENSOR_POS_NET_NORM2,
+    LLM_TENSOR_POS_NET_ATTN_NORM,
+    LLM_TENSOR_POS_NET_ATTN_Q,
+    LLM_TENSOR_POS_NET_ATTN_K,
+    LLM_TENSOR_POS_NET_ATTN_V,
+    LLM_TENSOR_POS_NET_ATTN_OUT,
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
@@ -1593,6 +1629,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
         },
     },
+    {
+        LLM_ARCH_WAVTOKENIZER_DEC,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,        "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM,   "token_embd_norm" },
+            { LLM_TENSOR_CONV1D,            "conv1d" },
+            { LLM_TENSOR_CONVNEXT_DW,       "convnext.%d.dw" },
+            { LLM_TENSOR_CONVNEXT_NORM,     "convnext.%d.norm" },
+            { LLM_TENSOR_CONVNEXT_PW1,      "convnext.%d.pw1" },
+            { LLM_TENSOR_CONVNEXT_PW2,      "convnext.%d.pw2" },
+            { LLM_TENSOR_CONVNEXT_GAMMA,    "convnext.%d.gamma" },
+            { LLM_TENSOR_OUTPUT_NORM,       "output_norm" },
+            { LLM_TENSOR_OUTPUT,            "output" },
+            { LLM_TENSOR_POS_NET_CONV1,     "posnet.%d.conv1" },
+            { LLM_TENSOR_POS_NET_CONV2,     "posnet.%d.conv2" },
+            { LLM_TENSOR_POS_NET_NORM,      "posnet.%d.norm" },
+            { LLM_TENSOR_POS_NET_NORM1,     "posnet.%d.norm1" },
+            { LLM_TENSOR_POS_NET_NORM2,     "posnet.%d.norm2" },
+            { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" },
+            { LLM_TENSOR_POS_NET_ATTN_Q,    "posnet.%d.attn_q" },
+            { LLM_TENSOR_POS_NET_ATTN_K,    "posnet.%d.attn_k" },
+            { LLM_TENSOR_POS_NET_ATTN_V,    "posnet.%d.attn_v" },
+            { LLM_TENSOR_POS_NET_ATTN_OUT,  "posnet.%d.attn_output" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2483,15 +2544,26 @@ static const size_t kiB = 1024;
 static const size_t MiB = 1024*kiB;
 static const size_t GiB = 1024*MiB;
 
+struct llama_hparams_posnet {
+    uint32_t n_embd;
+    uint32_t n_layer;
+};
+
+struct llama_hparams_convnext {
+    uint32_t n_embd;
+    uint32_t n_layer;
+};
+
 struct llama_hparams {
     bool vocab_only;
     bool rope_finetuned;
     bool use_par_res;
     bool swin_norm;
 
-    uint32_t n_vocab;
+    uint32_t n_vocab = 0;
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
+    uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
@@ -2502,6 +2574,10 @@ struct llama_hparams {
     uint32_t n_vocab_type = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;
 
+    // for WavTokenizer
+    struct llama_hparams_posnet   posnet;
+    struct llama_hparams_convnext convnext;
+
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
@@ -2516,6 +2592,9 @@ struct llama_hparams {
 
     float f_norm_eps;
     float f_norm_rms_eps;
+    float f_norm_group_eps;
+
+    uint32_t n_norm_groups;
 
     float f_attn_logit_softcapping = 50.0f;
     float f_final_logit_softcapping = 30.0f;
@@ -2561,66 +2640,6 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
-    bool operator!=(const llama_hparams & other) const {
-        if (this->vocab_only    != other.vocab_only)    return true;
-        if (this->n_vocab       != other.n_vocab)       return true;
-        if (this->n_ctx_train   != other.n_ctx_train)   return true;
-        if (this->n_embd        != other.n_embd)        return true;
-        if (this->n_layer       != other.n_layer)       return true;
-        if (this->n_rot         != other.n_rot)         return true;
-        if (this->n_swa         != other.n_swa)         return true;
-        if (this->n_embd_head_k != other.n_embd_head_k) return true;
-        if (this->n_embd_head_v != other.n_embd_head_v) return true;
-        if (this->n_expert      != other.n_expert)      return true;
-        if (this->n_expert_used != other.n_expert_used) return true;
-
-        if (this->n_head_arr    != other.n_head_arr)    return true;
-        if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
-        if (this->n_ff_arr      != other.n_ff_arr)      return true;
-
-        if (this->n_rel_attn_bkts    != other.n_rel_attn_bkts)    return true;
-        if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
-        if (this->n_lora_q           != other.n_lora_q)           return true;
-        if (this->n_lora_kv          != other.n_lora_kv)          return true;
-        if (this->n_ff_exp           != other.n_ff_exp)           return true;
-        if (this->n_ff_shexp         != other.n_ff_shexp)         return true;
-        if (this->n_expert_shared    != other.n_expert_shared)    return true;
-
-        if (this->rope_finetuned  != other.rope_finetuned)  return true;
-        if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
-        if (std::equal(std::begin(this->rope_sections),
-                       std::end(this->rope_sections),
-                       std::begin(other.rope_sections)))    return true;
-
-        if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
-        if (this->ssm_d_inner != other.ssm_d_inner) return true;
-        if (this->ssm_d_state != other.ssm_d_state) return true;
-        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
-        if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
-
-        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
-        if (this->time_mix_extra_dim     != other.time_mix_extra_dim)     return true;
-        if (this->time_decay_extra_dim   != other.time_decay_extra_dim)   return true;
-        if (this->wkv_head_size          != other.wkv_head_size)          return true;
-
-        if (this->dec_start_token_id != other.dec_start_token_id) return true;
-
-        const float EPSILON = 1e-9f;
-
-        if (!is_float_close(this->f_norm_eps,            other.f_norm_eps,            EPSILON)) return true;
-        if (!is_float_close(this->f_norm_rms_eps,        other.f_norm_rms_eps,        EPSILON)) return true;
-        if (!is_float_close(this->rope_attn_factor,      other.rope_attn_factor,      EPSILON)) return true;
-        if (!is_float_close(this->rope_freq_base_train,  other.rope_freq_base_train,  EPSILON)) return true;
-        if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
-        if (!is_float_close(this->expert_weights_scale,  other.expert_weights_scale,  EPSILON)) return true;
-        if (!is_float_close(this->rope_yarn_log_mul,     other.rope_yarn_log_mul,     EPSILON)) return true;
-        if (!is_float_close(this->f_residual_scale,      other.f_residual_scale,      EPSILON)) return true;
-        if (!is_float_close(this->f_embedding_scale,     other.f_embedding_scale,     EPSILON)) return true;
-        if (!is_float_close(this->f_attention_scale,     other.f_attention_scale,     EPSILON)) return true;
-
-        return false;
-    }
-
     uint32_t n_head(uint32_t il = 0) const {
         if (il < n_layer) {
             return n_head_arr[il];
@@ -2673,21 +2692,21 @@ struct llama_hparams {
         if (wkv_head_size != 0) {
             // for RWKV models
             return 2 * n_embd;
-        } else {
-            // TODO: maybe support other convolution strides than 1
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
         }
+
+        // TODO: maybe support other convolution strides than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
     }
 
     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
         if (wkv_head_size != 0) {
             // corresponds to RWKV's wkv_states size
             return n_embd * wkv_head_size;
-        } else {
-            // corresponds to Mamba's ssm_states size
-            return ssm_d_state * ssm_d_inner;
         }
+
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
     }
 };
 
@@ -2725,142 +2744,187 @@ struct llama_cparams {
     void * cb_eval_user_data;
 };
 
-// TODO: separate into "llama_layer_enc" and "llama_layer_dec"
-struct llama_layer {
-    llama_layer() {
-        // initialize all pointers to NULL
-        std::memset(this, 0, sizeof(*this));
-    }
+struct llama_layer_posnet {
+    // resnet
+    struct ggml_tensor * norm1   = nullptr;
+    struct ggml_tensor * norm1_b = nullptr;
+
+    struct ggml_tensor * conv1   = nullptr;
+    struct ggml_tensor * conv1_b = nullptr;
 
+    struct ggml_tensor * norm2   = nullptr;
+    struct ggml_tensor * norm2_b = nullptr;
+
+    struct ggml_tensor * conv2   = nullptr;
+    struct ggml_tensor * conv2_b = nullptr;
+
+    // attention
+    struct ggml_tensor * attn_norm   = nullptr;
+    struct ggml_tensor * attn_norm_b = nullptr;
+
+    struct ggml_tensor * attn_q   = nullptr;
+    struct ggml_tensor * attn_q_b = nullptr;
+
+    struct ggml_tensor * attn_k   = nullptr;
+    struct ggml_tensor * attn_k_b = nullptr;
+
+    struct ggml_tensor * attn_v   = nullptr;
+    struct ggml_tensor * attn_v_b = nullptr;
+
+    struct ggml_tensor * attn_o   = nullptr;
+    struct ggml_tensor * attn_o_b = nullptr;
+
+    // normalize
+    struct ggml_tensor * norm   = nullptr;
+    struct ggml_tensor * norm_b = nullptr;
+};
+
+struct llama_layer_convnext {
+    struct ggml_tensor * dw   = nullptr;
+    struct ggml_tensor * dw_b = nullptr;
+
+    struct ggml_tensor * norm   = nullptr;
+    struct ggml_tensor * norm_b = nullptr;
+
+    struct ggml_tensor * pw1   = nullptr;
+    struct ggml_tensor * pw1_b = nullptr;
+
+    struct ggml_tensor * pw2   = nullptr;
+    struct ggml_tensor * pw2_b = nullptr;
+
+    struct ggml_tensor * gamma = nullptr;
+};
+
+struct llama_layer {
     // normalization
-    struct ggml_tensor * attn_norm;
-    struct ggml_tensor * attn_norm_b;
-    struct ggml_tensor * attn_norm_2;
-    struct ggml_tensor * attn_norm_2_b;
-    struct ggml_tensor * attn_q_norm;
-    struct ggml_tensor * attn_q_norm_b;
-    struct ggml_tensor * attn_k_norm;
-    struct ggml_tensor * attn_k_norm_b;
-    struct ggml_tensor * attn_out_norm;
-    struct ggml_tensor * attn_out_norm_b;
-    struct ggml_tensor * attn_q_a_norm;
-    struct ggml_tensor * attn_kv_a_norm;
-    struct ggml_tensor * attn_sub_norm;
-    struct ggml_tensor * attn_post_norm;
-    struct ggml_tensor * ffn_sub_norm;
-    struct ggml_tensor * attn_norm_cross;
-    struct ggml_tensor * attn_norm_enc;
+    struct ggml_tensor * attn_norm       = nullptr;
+    struct ggml_tensor * attn_norm_b     = nullptr;
+    struct ggml_tensor * attn_norm_2     = nullptr;
+    struct ggml_tensor * attn_norm_2_b   = nullptr;
+    struct ggml_tensor * attn_q_norm     = nullptr;
+    struct ggml_tensor * attn_q_norm_b   = nullptr;
+    struct ggml_tensor * attn_k_norm     = nullptr;
+    struct ggml_tensor * attn_k_norm_b   = nullptr;
+    struct ggml_tensor * attn_out_norm   = nullptr;
+    struct ggml_tensor * attn_out_norm_b = nullptr;
+    struct ggml_tensor * attn_q_a_norm   = nullptr;
+    struct ggml_tensor * attn_kv_a_norm  = nullptr;
+    struct ggml_tensor * attn_sub_norm   = nullptr;
+    struct ggml_tensor * attn_post_norm  = nullptr;
+    struct ggml_tensor * ffn_sub_norm    = nullptr;
+    struct ggml_tensor * attn_norm_cross = nullptr;
+    struct ggml_tensor * attn_norm_enc   = nullptr;
 
     // attention
-    struct ggml_tensor * wq;
-    struct ggml_tensor * wk;
-    struct ggml_tensor * wv;
-    struct ggml_tensor * wo;
-    struct ggml_tensor * wqkv;
-    struct ggml_tensor * wq_a;
-    struct ggml_tensor * wq_b;
-    struct ggml_tensor * wkv_a_mqa;
-    struct ggml_tensor * wkv_b;
-    struct ggml_tensor * wq_cross;
-    struct ggml_tensor * wk_cross;
-    struct ggml_tensor * wv_cross;
-    struct ggml_tensor * wo_cross;
-    struct ggml_tensor * wq_enc;
-    struct ggml_tensor * wk_enc;
-    struct ggml_tensor * wv_enc;
-    struct ggml_tensor * wo_enc;
+    struct ggml_tensor * wq        = nullptr;
+    struct ggml_tensor * wk        = nullptr;
+    struct ggml_tensor * wv        = nullptr;
+    struct ggml_tensor * wo        = nullptr;
+    struct ggml_tensor * wqkv      = nullptr;
+    struct ggml_tensor * wq_a      = nullptr;
+    struct ggml_tensor * wq_b      = nullptr;
+    struct ggml_tensor * wkv_a_mqa = nullptr;
+    struct ggml_tensor * wkv_b     = nullptr;
+    struct ggml_tensor * wq_cross  = nullptr;
+    struct ggml_tensor * wk_cross  = nullptr;
+    struct ggml_tensor * wv_cross  = nullptr;
+    struct ggml_tensor * wo_cross  = nullptr;
+    struct ggml_tensor * wq_enc    = nullptr;
+    struct ggml_tensor * wk_enc    = nullptr;
+    struct ggml_tensor * wv_enc    = nullptr;
+    struct ggml_tensor * wo_enc    = nullptr;
 
     // attention bias
-    struct ggml_tensor * bq;
-    struct ggml_tensor * bk;
-    struct ggml_tensor * bv;
-    struct ggml_tensor * bo;
-    struct ggml_tensor * bqkv;
+    struct ggml_tensor * bq   = nullptr;
+    struct ggml_tensor * bk   = nullptr;
+    struct ggml_tensor * bv   = nullptr;
+    struct ggml_tensor * bo   = nullptr;
+    struct ggml_tensor * bqkv = nullptr;
 
     // relative position bias
-    struct ggml_tensor * attn_rel_b;
-    struct ggml_tensor * attn_rel_b_enc;
-    struct ggml_tensor * attn_rel_b_cross;
+    struct ggml_tensor * attn_rel_b       = nullptr;
+    struct ggml_tensor * attn_rel_b_enc   = nullptr;
+    struct ggml_tensor * attn_rel_b_cross = nullptr;
 
     // normalization
-    struct ggml_tensor * ffn_norm;
-    struct ggml_tensor * ffn_norm_b;
-    struct ggml_tensor * ffn_post_norm;
-    struct ggml_tensor * layer_out_norm;
-    struct ggml_tensor * layer_out_norm_b;
-    struct ggml_tensor * ffn_norm_exps;
-    struct ggml_tensor * ffn_norm_enc;
+    struct ggml_tensor * ffn_norm         = nullptr;
+    struct ggml_tensor * ffn_norm_b       = nullptr;
+    struct ggml_tensor * ffn_post_norm    = nullptr;
+    struct ggml_tensor * layer_out_norm   = nullptr;
+    struct ggml_tensor * layer_out_norm_b = nullptr;
+    struct ggml_tensor * ffn_norm_exps    = nullptr;
+    struct ggml_tensor * ffn_norm_enc     = nullptr;
 
     // ff
-    struct ggml_tensor * ffn_gate; // w1
-    struct ggml_tensor * ffn_down; // w2
-    struct ggml_tensor * ffn_up;   // w3
-    struct ggml_tensor * ffn_gate_enc;
-    struct ggml_tensor * ffn_down_enc;
-    struct ggml_tensor * ffn_up_enc;
+    struct ggml_tensor * ffn_gate     = nullptr; // w1
+    struct ggml_tensor * ffn_down     = nullptr; // w2
+    struct ggml_tensor * ffn_up       = nullptr; // w3
+    struct ggml_tensor * ffn_gate_enc = nullptr;
+    struct ggml_tensor * ffn_down_enc = nullptr;
+    struct ggml_tensor * ffn_up_enc   = nullptr;
 
     // ff MoE
-    struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exps;
-    struct ggml_tensor * ffn_down_exps;
-    struct ggml_tensor * ffn_up_exps ;
+    struct ggml_tensor * ffn_gate_inp  = nullptr;
+    struct ggml_tensor * ffn_gate_exps = nullptr;
+    struct ggml_tensor * ffn_down_exps = nullptr;
+    struct ggml_tensor * ffn_up_exps   = nullptr;
 
     // ff shared expert (shexp)
-    struct ggml_tensor * ffn_gate_inp_shexp;
-    struct ggml_tensor * ffn_gate_shexp;
-    struct ggml_tensor * ffn_down_shexp;
-    struct ggml_tensor * ffn_up_shexp;
+    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
+    struct ggml_tensor * ffn_gate_shexp     = nullptr;
+    struct ggml_tensor * ffn_down_shexp     = nullptr;
+    struct ggml_tensor * ffn_up_shexp       = nullptr;
 
     // ff bias
-    struct ggml_tensor * ffn_gate_b;
-    struct ggml_tensor * ffn_down_b; // b2
-    struct ggml_tensor * ffn_up_b; // b3
-    struct ggml_tensor * ffn_act;
+    struct ggml_tensor * ffn_gate_b = nullptr;
+    struct ggml_tensor * ffn_down_b = nullptr; // b2
+    struct ggml_tensor * ffn_up_b   = nullptr; // b3
+    struct ggml_tensor * ffn_act    = nullptr;
 
     // mamba proj
-    struct ggml_tensor * ssm_in;
-    struct ggml_tensor * ssm_x;
-    struct ggml_tensor * ssm_dt;
-    struct ggml_tensor * ssm_out;
+    struct ggml_tensor * ssm_in  = nullptr;
+    struct ggml_tensor * ssm_x   = nullptr;
+    struct ggml_tensor * ssm_dt  = nullptr;
+    struct ggml_tensor * ssm_out = nullptr;
 
     // mamba
-    struct ggml_tensor * ssm_conv1d;
-    struct ggml_tensor * ssm_a;
-    struct ggml_tensor * ssm_d;
+    struct ggml_tensor * ssm_conv1d = nullptr;
+    struct ggml_tensor * ssm_a      = nullptr;
+    struct ggml_tensor * ssm_d      = nullptr;
 
     // mamba bias
-    struct ggml_tensor * ssm_conv1d_b;
-    struct ggml_tensor * ssm_dt_b;
+    struct ggml_tensor * ssm_conv1d_b = nullptr;
+    struct ggml_tensor * ssm_dt_b     = nullptr;
 
     // rwkv
-    struct ggml_tensor * time_mix_w1;
-    struct ggml_tensor * time_mix_w2;
-    struct ggml_tensor * time_mix_lerp_x;
-    struct ggml_tensor * time_mix_lerp_w;
-    struct ggml_tensor * time_mix_lerp_k;
-    struct ggml_tensor * time_mix_lerp_v;
-    struct ggml_tensor * time_mix_lerp_r;
-    struct ggml_tensor * time_mix_lerp_g;
-
-    struct ggml_tensor * time_mix_first;
-    struct ggml_tensor * time_mix_decay;
-    struct ggml_tensor * time_mix_decay_w1;
-    struct ggml_tensor * time_mix_decay_w2;
-    struct ggml_tensor * time_mix_key;
-    struct ggml_tensor * time_mix_value;
-    struct ggml_tensor * time_mix_receptance;
-    struct ggml_tensor * time_mix_gate;
-
-    struct ggml_tensor * time_mix_ln;
-    struct ggml_tensor * time_mix_ln_b;
-    struct ggml_tensor * time_mix_output;
-
-    struct ggml_tensor * channel_mix_lerp_k;
-    struct ggml_tensor * channel_mix_lerp_r;
-
-    struct ggml_tensor * channel_mix_key;
-    struct ggml_tensor * channel_mix_receptance;
-    struct ggml_tensor * channel_mix_value;
+    struct ggml_tensor * time_mix_w1         = nullptr;
+    struct ggml_tensor * time_mix_w2         = nullptr;
+    struct ggml_tensor * time_mix_lerp_x     = nullptr;
+    struct ggml_tensor * time_mix_lerp_w     = nullptr;
+    struct ggml_tensor * time_mix_lerp_k     = nullptr;
+    struct ggml_tensor * time_mix_lerp_v     = nullptr;
+    struct ggml_tensor * time_mix_lerp_r     = nullptr;
+    struct ggml_tensor * time_mix_lerp_g     = nullptr;
+
+    struct ggml_tensor * time_mix_first      = nullptr;
+    struct ggml_tensor * time_mix_decay      = nullptr;
+    struct ggml_tensor * time_mix_decay_w1   = nullptr;
+    struct ggml_tensor * time_mix_decay_w2   = nullptr;
+    struct ggml_tensor * time_mix_key        = nullptr;
+    struct ggml_tensor * time_mix_value      = nullptr;
+    struct ggml_tensor * time_mix_receptance = nullptr;
+    struct ggml_tensor * time_mix_gate       = nullptr;
+
+    struct ggml_tensor * time_mix_ln     = nullptr;
+    struct ggml_tensor * time_mix_ln_b   = nullptr;
+    struct ggml_tensor * time_mix_output = nullptr;
+
+    struct ggml_tensor * channel_mix_lerp_k = nullptr;
+    struct ggml_tensor * channel_mix_lerp_r = nullptr;
+
+    struct ggml_tensor * channel_mix_key        = nullptr;
+    struct ggml_tensor * channel_mix_receptance = nullptr;
+    struct ggml_tensor * channel_mix_value      = nullptr;
 
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
@@ -2868,13 +2932,17 @@ struct llama_layer {
     struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
-    struct ggml_tensor * wq_scale;
-    struct ggml_tensor * wk_scale;
-    struct ggml_tensor * wv_scale;
-    struct ggml_tensor * wo_scale;
-    struct ggml_tensor * ffn_gate_scale;
-    struct ggml_tensor * ffn_up_scale;
-    struct ggml_tensor * ffn_down_scale;
+    struct ggml_tensor * wq_scale       = nullptr;
+    struct ggml_tensor * wk_scale       = nullptr;
+    struct ggml_tensor * wv_scale       = nullptr;
+    struct ggml_tensor * wo_scale       = nullptr;
+    struct ggml_tensor * ffn_gate_scale = nullptr;
+    struct ggml_tensor * ffn_up_scale   = nullptr;
+    struct ggml_tensor * ffn_down_scale = nullptr;
+
+    struct llama_layer_posnet posnet;
+
+    struct llama_layer_convnext convnext;
 };
 
 // very similar to llama_batch,
@@ -3005,6 +3073,9 @@ struct llama_model {
     struct ggml_tensor * cls_out   = nullptr;
     struct ggml_tensor * cls_out_b = nullptr;
 
+    struct ggml_tensor * conv1d = nullptr;
+    struct ggml_tensor * conv1d_b = nullptr;
+
     std::vector<llama_layer> layers;
 
     // gguf metadata
@@ -3089,6 +3160,7 @@ struct llama_sbatch {
     // batch indices of the output
     std::vector<size_t> out_ids;
     std::vector<llama_sbatch_seq> seq;
+
     const llama_batch * batch = nullptr;
 
     // buffers for the ubatch
@@ -3509,6 +3581,17 @@ static int llama_get_device_count(const llama_model & model) {
     return (int) model.devices.size();
 }
 
+static struct ggml_tensor * llama_get_model_tensor(const struct llama_model * model, const char * name) {
+    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
+            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
+                return it.first == name;
+            });
+    if (it == model->tensors_by_name.end()) {
+        return nullptr;
+    }
+    return it->second;
+}
+
 template<typename F>
 static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
     ggml_init_params params = {
@@ -3562,7 +3645,9 @@ static bool llama_kv_cache_init(
 
     const struct llama_hparams & hparams = model.hparams;
 
-    const int64_t  n_layer = hparams.n_layer;
+    const int32_t n_layer = hparams.n_layer;
+
+    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
 
     cache.has_shift = false;
 
@@ -3603,10 +3688,12 @@ static bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);
 
-    for (int i = 0; i < (int) n_layer; i++) {
+    for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
+        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
+
         ggml_backend_buffer_type_t buft;
         if (offload) {
             auto * dev = model.dev_layer.at(i).dev;
@@ -5519,7 +5606,7 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
 
     // get hparams kv
-    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
+    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
 
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
@@ -5532,6 +5619,16 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
 
+    if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) {
+        ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
+
+        ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd);
+        ml.get_key(LLM_KV_POSNET_BLOCK_COUNT,      hparams.posnet.n_layer);
+
+        ml.get_key(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd);
+        ml.get_key(LLM_KV_CONVNEXT_BLOCK_COUNT,      hparams.convnext.n_layer);
+    }
+
     GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
     if (hparams.n_expert > 0) {
@@ -5540,13 +5637,13 @@ static void llm_load_hparams(
         GGML_ASSERT(hparams.n_expert_used == 0);
     }
 
-    // zero-out the per-layer hparams
+    // zero-out the array hparams
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
 
-    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer);
-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
+    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
+    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv_arr = hparams.n_head_arr;
@@ -6291,6 +6388,13 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS,    hparams.f_norm_group_eps);
+                ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
+            } break;
         default: (void)0;
     }
 
@@ -6320,7 +6424,7 @@ static void llm_load_vocab(
         ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
         ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
 
-        if (tokenizer_model == "no_vocab") {
+        if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
             vocab.type = LLAMA_VOCAB_TYPE_NONE;
 
             // default special tokens
@@ -7299,6 +7403,22 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
+    {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
+    {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_POS_NET_NORM2,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_POS_NET_CONV1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
+    {LLM_TENSOR_POS_NET_CONV2,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
+    {LLM_TENSOR_POS_NET_ATTN_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_POS_NET_ATTN_Q,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_POS_NET_ATTN_K,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_POS_NET_ATTN_V,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_POS_NET_ATTN_OUT,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CONVNEXT_DW,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
+    {LLM_TENSOR_CONVNEXT_NORM,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CONVNEXT_PW1,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CONVNEXT_PW2,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CONVNEXT_GAMMA,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
 // checks if the weight tensor can be used with the specified buffer type and device
@@ -7403,6 +7523,12 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
                 ggml_tensor  * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
                 op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
             } break;
+        case GGML_OP_IM2COL:
+            {
+                const int n_embd = hparams.n_embd;
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
+                op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
+            } break;
         default:
             GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
     }
@@ -7533,7 +7659,8 @@ static bool llm_load_tensors(
     model.main_gpu     = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
 
-    const int n_layer     = hparams.n_layer;
+    const int n_layer = hparams.n_layer;
+
     bool use_mmap_buffer = true;
 
     // build a list of buffer types for the CPU and GPU devices
@@ -9336,9 +9463,9 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_CHAMELEON:
                 {
-                 model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-                 // output
+                    // output
                     model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     // if output is NULL, init from the input tok embed
@@ -9367,6 +9494,109 @@ static bool llm_load_tensors(
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_WAVTOKENIZER_DEC:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
+
+                    model.conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
+                    model.conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
+
+                    // posnet
+                    {
+                        const int64_t n_embd = hparams.posnet.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
+                            auto & layer = model.layers[i].posnet;
+
+                            // posnet:
+                            //
+                            //  - resnet
+                            //  - resnet
+                            //  - attn
+                            //  - resnet
+                            //  - resnet
+                            //  - norm
+                            //
+                            switch (i) {
+                                case 0:
+                                case 1:
+                                case 3:
+                                case 4:
+                                    {
+                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
+                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
+                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 2:
+                                    {
+                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 5:
+                                    {
+                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                default: GGML_ABORT("unknown posnet layer");
+                            };
+                        }
+                    }
+
+                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
+
+                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
+                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
+
+                    // convnext
+                    {
+                        const int64_t n_embd = hparams.convnext.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
+                            auto & layer = model.layers[i].convnext;
+
+                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
+                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
+
+                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
+                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
+
+                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
+                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
+
+                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
+                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
+
+                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
+                        }
+
+                        // output
+                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    }
+
+                    model.output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
+                    model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -9586,6 +9816,7 @@ enum llm_ffn_gate_type {
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
+    LLM_NORM_GROUP,
 };
 
 static struct ggml_tensor * llm_build_inp_embd(
@@ -9606,7 +9837,7 @@ static struct ggml_tensor * llm_build_inp_embd(
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-       lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -9727,8 +9958,14 @@ static struct ggml_tensor * llm_build_norm(
          const llm_build_cb & cb,
                         int   il) {
     switch (type) {
-        case LLM_NORM:     cur = ggml_norm    (ctx, cur, hparams.f_norm_eps);     break;
-        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM:       cur = ggml_norm      (ctx, cur, hparams.f_norm_eps);     break;
+        case LLM_NORM_RMS:   cur = ggml_rms_norm  (ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM_GROUP:
+            {
+                cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
+                cur = ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
+                cur = ggml_reshape_2d(ctx, cur, cur->ne[0],    cur->ne[2]);
+            } break;
     }
 
     if (mw || mb) {
@@ -15854,7 +16091,7 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_t5_encoder() {
+    struct ggml_cgraph * build_t5_enc() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -15986,7 +16223,7 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_t5_decoder() {
+    struct ggml_cgraph * build_t5_dec() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -16935,6 +17172,158 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_wavtokenizer_dec() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
+
+        cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
+        cur = ggml_add(ctx0, cur, model.conv1d_b);
+
+        // posnet
+        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
+            const auto & layer = model.layers[il].posnet;
+
+            inpL = cur;
+
+            switch (il) {
+                case 0:
+                case 1:
+                case 3:
+                case 4:
+                    {
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.norm1,
+                                layer.norm1_b,
+                                LLM_NORM_GROUP, cb, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv1_b);
+
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.norm2,
+                                layer.norm2_b,
+                                LLM_NORM_GROUP, cb, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv2_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 2:
+                    {
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.attn_norm,
+                                layer.attn_norm_b,
+                                LLM_NORM_GROUP, cb, 0);
+
+                        struct ggml_tensor * q;
+                        struct ggml_tensor * k;
+                        struct ggml_tensor * v;
+
+                        q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
+                        k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
+                        v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
+
+                        q = ggml_add(ctx0, q, layer.attn_q_b);
+                        k = ggml_add(ctx0, k, layer.attn_k_b);
+                        v = ggml_add(ctx0, v, layer.attn_v_b);
+
+                        q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
+                        k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
+
+                        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+
+                        kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
+
+                        cur = ggml_mul_mat(ctx0, kq, v);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.attn_o_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 5:
+                    {
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.norm,
+                                layer.norm_b,
+                                LLM_NORM_GROUP, cb, 0);
+                    } break;
+                default: GGML_ABORT("unknown posnet layer");
+            };
+        }
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, cb, -1);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        inpL = cur;
+
+        // convnext
+        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
+            const auto & layer = model.layers[il].convnext;
+
+            cur = inpL;
+
+            cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
+            cur = ggml_add(ctx0, cur, layer.dw_b);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    layer.norm,
+                    layer.norm_b,
+                    LLM_NORM, cb, -1);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    layer.pw1, layer.pw1_b, NULL,
+                    NULL,      NULL,        NULL,
+                    layer.pw2, layer.pw2_b, NULL,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+
+            cur = ggml_mul(ctx0, cur, layer.gamma);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            inpL = ggml_add(ctx0, cur, inpL);
+        }
+
+        cur = inpL;
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+        cb(cur, "result_embd", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -17181,14 +17570,14 @@ static struct ggml_cgraph * llama_build_graph(
         case LLM_ARCH_T5:
             {
                 if (lctx.is_encoding) {
-                    result = llm.build_t5_encoder();
+                    result = llm.build_t5_enc();
                 } else {
-                    result = llm.build_t5_decoder();
+                    result = llm.build_t5_dec();
                 }
             } break;
         case LLM_ARCH_T5ENCODER:
             {
-                result = llm.build_t5_encoder();
+                result = llm.build_t5_enc();
             } break;
         case LLM_ARCH_JAIS:
             {
@@ -17210,6 +17599,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_chameleon();
             } break;
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            {
+                result = llm.build_wavtokenizer_dec();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -17301,30 +17694,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-        const int64_t n_tokens = ubatch.n_tokens;
+        //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
-        int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+        if (!lctx.inp_out_ids) {
+            LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
+        } else {
+            const int64_t n_tokens = ubatch.n_tokens;
 
-        if (lctx.n_outputs == n_tokens) {
-            for (int i = 0; i < n_tokens; ++i) {
-                data[i] = i;
-            }
-        } else if (ubatch.output) {
-            int32_t n_outputs = 0;
-            for (int i = 0; i < n_tokens; ++i) {
-                if (ubatch.output[i]) {
-                    data[n_outputs++] = i;
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
+            int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+
+            if (lctx.n_outputs == n_tokens) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[i] = i;
                 }
+            } else if (ubatch.output) {
+                int32_t n_outputs = 0;
+                for (int i = 0; i < n_tokens; ++i) {
+                    if (ubatch.output[i]) {
+                        data[n_outputs++] = i;
+                    }
+                }
+                // the graph needs to have been passed the correct number of outputs
+                GGML_ASSERT(lctx.n_outputs == n_outputs);
+            } else if (lctx.n_outputs == 1) {
+                // only keep last output
+                data[0] = n_tokens - 1;
+            } else {
+                GGML_ASSERT(lctx.n_outputs == 0);
             }
-            // the graph needs to have been passed the correct number of outputs
-            GGML_ASSERT(lctx.n_outputs == n_outputs);
-        } else if (lctx.n_outputs == 1) {
-            // only keep last output
-            data[0] = n_tokens - 1;
-        } else {
-            GGML_ASSERT(lctx.n_outputs == 0);
         }
     }
 
@@ -17995,6 +18393,7 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
@@ -20383,6 +20782,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -20500,17 +20900,6 @@ uint64_t llama_model_n_params(const struct llama_model * model) {
     return model->n_elements;
 }
 
-struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
-    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
-            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
-                return it.first == name;
-            });
-    if (it == model->tensors_by_name.end()) {
-        return nullptr;
-    }
-    return it->second;
-}
-
 bool llama_model_has_encoder(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_T5:        return true;