--- /dev/null
+sync.sh
+main
+*.o
--- /dev/null
+main: ggml.o main.o
+ g++ -o main ggml.o main.o
+
+ggml.o: ggml.c ggml.h
+ gcc -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
+
+main.o: main.cpp ggml.h
+ g++ -O3 -std=c++11 -c main.cpp
+
+# clean up the directory
+clean:
+ rm -f *.o main
+
+# run the program
+run: main
+ ./main
+
+# download the following audio samples into folder "./samples":
+.PHONY: samples
+samples:
+ @echo "Downloading samples..."
+ mkdir -p samples
+ @wget --quiet --show-progress -O samples/gb0.ogg https://upload.wikimedia.org/wikipedia/commons/2/22/George_W._Bush%27s_weekly_radio_address_%28November_1%2C_2008%29.oga
+ @wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
+ @wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
+ @echo "Converting to 16-bit WAV ..."
+ @ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
+ @ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
+ @ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
+
+.PHONY: tiny.en
+tiny.en: main
+ @echo "Downloading tiny.en (75 MB just once)"
+ mkdir -p models
+ @if [ ! -f models/ggml-tiny.en.bin ]; then \
+ wget --quiet --show-progress -O models/ggml-tiny.en.bin https://ggml.ggerganov.com/ggml-model-whisper-tiny.en.bin ; \
+ fi
+ @echo "==============================================="
+ @echo "Running tiny.en on all samples in ./samples ..."
+ @echo "==============================================="
+ @echo ""
+ @for f in samples/*.wav; do \
+ echo "----------------------------------------------" ; \
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
+ echo "----------------------------------------------" ; \
+ echo "" ; \
+ ./main -m models/ggml-tiny.en.bin -f $$f ; \
+ echo "" ; \
+ done
+
+.PHONY: base.en
+base.en: main
+ @echo "Downloading base.en (142 MB just once)"
+ mkdir -p models
+ @if [ ! -f models/ggml-base.en.bin ]; then \
+ wget --quiet --show-progress -O models/ggml-base.en.bin https://ggml.ggerganov.com/ggml-model-whisper-base.en.bin ; \
+ fi
+ @echo "==============================================="
+ @echo "Running base.en on all samples in ./samples ..."
+ @echo "==============================================="
+ @echo ""
+ @for f in samples/*.wav; do \
+ echo "----------------------------------------------" ; \
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
+ echo "----------------------------------------------" ; \
+ echo "" ; \
+ ./main -m models/ggml-base.en.bin -f $$f ; \
+ echo "" ; \
+ done
+
+.PHONY: small.en
+small.en: main
+ @echo "Downloading small.en (466 MB just once)"
+ mkdir -p models
+ @if [ ! -f models/ggml-small.en.bin ]; then \
+ wget --quiet --show-progress -O models/ggml-small.en.bin https://ggml.ggerganov.com/ggml-model-whisper-small.en.bin ; \
+ fi
+ @echo "==============================================="
+ @echo "Running small.en on all samples in ./samples ..."
+ @echo "==============================================="
+ @echo ""
+ @for f in samples/*.wav; do \
+ echo "----------------------------------------------" ; \
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
+ echo "----------------------------------------------" ; \
+ echo "" ; \
+ ./main -m models/ggml-small.en.bin -f $$f ; \
+ echo "" ; \
+ done
+
+.PHONY: medium.en
+medium.en: main
+ @echo "Downloading medium.en (1.5 GB just once)"
+ mkdir -p models
+ @if [ ! -f models/ggml-medium.en.bin ]; then \
+ wget --quiet --show-progress -O models/ggml-medium.en.bin https://ggml.ggerganov.com/ggml-model-whisper-medium.en.bin ; \
+ fi
+ @echo "==============================================="
+ @echo "Running medium.en on all samples in ./samples ..."
+ @echo "==============================================="
+ @echo ""
+ @for f in samples/*.wav; do \
+ echo "----------------------------------------------" ; \
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
+ echo "----------------------------------------------" ; \
+ echo "" ; \
+ ./main -m models/ggml-medium.en.bin -f $$f ; \
+ echo "" ; \
+ done
--- /dev/null
+# Convert Whisper transformer model from PyTorch to ggml format
+#
+# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
+#
+# You need to clone the original repo in ~/path/to/repo/whisper/
+#
+# git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
+#
+# It is used to various assets needed by the algorithm:
+#
+# - tokenizer
+# - mel filters
+#
+# Also, you need to have the original models in ~/.cache/whisper/
+# See the original repo for more details.
+#
+# This script loads the specified model and whisper assets and saves them in ggml format.
+# The output is a single binary file containing the following information:
+#
+# - hparams
+# - mel filters
+# - tokenizer vocab
+# - model variables
+#
+# For each variable, write the following:
+#
+# - Number of dimensions (int)
+# - Name length (int)
+# - Dimensions (int[n_dims])
+# - Name (char[name_length])
+# - Data (float[n_dims])
+#
+
+import io
+import os
+import sys
+import struct
+import json
+import code
+import torch
+import numpy as np
+
+from transformers import GPTJForCausalLM
+from transformers import GPT2TokenizerFast
+
+# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
+LANGUAGES = {
+ "en": "english",
+ "zh": "chinese",
+ "de": "german",
+ "es": "spanish",
+ "ru": "russian",
+ "ko": "korean",
+ "fr": "french",
+ "ja": "japanese",
+ "pt": "portuguese",
+ "tr": "turkish",
+ "pl": "polish",
+ "ca": "catalan",
+ "nl": "dutch",
+ "ar": "arabic",
+ "sv": "swedish",
+ "it": "italian",
+ "id": "indonesian",
+ "hi": "hindi",
+ "fi": "finnish",
+ "vi": "vietnamese",
+ "iw": "hebrew",
+ "uk": "ukrainian",
+ "el": "greek",
+ "ms": "malay",
+ "cs": "czech",
+ "ro": "romanian",
+ "da": "danish",
+ "hu": "hungarian",
+ "ta": "tamil",
+ "no": "norwegian",
+ "th": "thai",
+ "ur": "urdu",
+ "hr": "croatian",
+ "bg": "bulgarian",
+ "lt": "lithuanian",
+ "la": "latin",
+ "mi": "maori",
+ "ml": "malayalam",
+ "cy": "welsh",
+ "sk": "slovak",
+ "te": "telugu",
+ "fa": "persian",
+ "lv": "latvian",
+ "bn": "bengali",
+ "sr": "serbian",
+ "az": "azerbaijani",
+ "sl": "slovenian",
+ "kn": "kannada",
+ "et": "estonian",
+ "mk": "macedonian",
+ "br": "breton",
+ "eu": "basque",
+ "is": "icelandic",
+ "hy": "armenian",
+ "ne": "nepali",
+ "mn": "mongolian",
+ "bs": "bosnian",
+ "kk": "kazakh",
+ "sq": "albanian",
+ "sw": "swahili",
+ "gl": "galician",
+ "mr": "marathi",
+ "pa": "punjabi",
+ "si": "sinhala",
+ "km": "khmer",
+ "sn": "shona",
+ "yo": "yoruba",
+ "so": "somali",
+ "af": "afrikaans",
+ "oc": "occitan",
+ "ka": "georgian",
+ "be": "belarusian",
+ "tg": "tajik",
+ "sd": "sindhi",
+ "gu": "gujarati",
+ "am": "amharic",
+ "yi": "yiddish",
+ "lo": "lao",
+ "uz": "uzbek",
+ "fo": "faroese",
+ "ht": "haitian creole",
+ "ps": "pashto",
+ "tk": "turkmen",
+ "nn": "nynorsk",
+ "mt": "maltese",
+ "sa": "sanskrit",
+ "lb": "luxembourgish",
+ "my": "myanmar",
+ "bo": "tibetan",
+ "tl": "tagalog",
+ "mg": "malagasy",
+ "as": "assamese",
+ "tt": "tatar",
+ "haw": "hawaiian",
+ "ln": "lingala",
+ "ha": "hausa",
+ "ba": "bashkir",
+ "jw": "javanese",
+ "su": "sundanese",
+}
+
+# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
+def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
+ tokenizer = GPT2TokenizerFast.from_pretrained(path)
+
+ specials = [
+ "<|startoftranscript|>",
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nocaptions|>",
+ "<|notimestamps|>",
+ ]
+
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+ return tokenizer
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+if len(sys.argv) < 4:
+ print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
+ sys.exit(1)
+
+fname_inp = sys.argv[1]
+dir_whisper = sys.argv[2]
+dir_out = sys.argv[3]
+
+# try to load PyTorch binary data
+try:
+ model_bytes = open(fname_inp, "rb").read()
+ with io.BytesIO(model_bytes) as fp:
+ checkpoint = torch.load(fp, map_location="cpu")
+except:
+ print("Error: failed to load PyTorch model file: %s" % fname_inp)
+ sys.exit(1)
+
+hparams = checkpoint["dims"]
+print("hparams:", hparams)
+
+list_vars = checkpoint["model_state_dict"]
+
+#print(list_vars['encoder.positional_embedding'])
+#print(list_vars['encoder.conv1.weight'])
+#print(list_vars['encoder.conv1.weight'].shape)
+
+# load mel filters
+n_mels = hparams["n_mels"]
+with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
+ filters = torch.from_numpy(f[f"mel_{n_mels}"])
+ #print (filters)
+
+#code.interact(local=locals())
+
+multilingual = hparams["n_vocab"] == 51865
+tokenizer = build_tokenizer(dir_whisper, multilingual and "multilingual" or "gpt2")
+
+#print(tokenizer)
+#print(tokenizer.name_or_path)
+#print(len(tokenizer.additional_special_tokens))
+dir_tokenizer = tokenizer.name_or_path
+
+# output in the same directory as the model
+fname_out = dir_out + "/ggml-model.bin"
+
+with open(dir_tokenizer + "/vocab.json", "r") as f:
+ tokens = json.load(f)
+
+# use 16-bit or 32-bit floats
+use_f16 = True
+if len(sys.argv) > 4:
+ use_f16 = False
+ fname_out = dir_out + "/ggml-model-f32.bin"
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["n_vocab"]))
+fout.write(struct.pack("i", hparams["n_audio_ctx"]))
+fout.write(struct.pack("i", hparams["n_audio_state"]))
+fout.write(struct.pack("i", hparams["n_audio_head"]))
+fout.write(struct.pack("i", hparams["n_audio_layer"]))
+fout.write(struct.pack("i", hparams["n_text_ctx"]))
+fout.write(struct.pack("i", hparams["n_text_state"]))
+fout.write(struct.pack("i", hparams["n_text_head"]))
+fout.write(struct.pack("i", hparams["n_text_layer"]))
+fout.write(struct.pack("i", hparams["n_mels"]))
+fout.write(struct.pack("i", use_f16))
+
+# write mel filters
+fout.write(struct.pack("i", filters.shape[0]))
+fout.write(struct.pack("i", filters.shape[1]))
+for i in range(filters.shape[0]):
+ for j in range(filters.shape[1]):
+ fout.write(struct.pack("f", filters[i][j]))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+fout.write(struct.pack("i", len(tokens)))
+
+for key in tokens:
+ text = bytearray([byte_decoder[c] for c in key]).decode('utf-8', errors='replace').encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+
+for name in list_vars.keys():
+ data = list_vars[name].squeeze().numpy()
+ print("Processing variable: " + name + " with shape: ", data.shape)
+
+ # reshape conv bias from [n] to [n, 1]
+ if name == "encoder.conv1.bias" or \
+ name == "encoder.conv2.bias":
+ data = data.reshape(data.shape[0], 1)
+ print(" Reshaped variable: " + name + " to shape: ", data.shape)
+
+ n_dims = len(data.shape);
+
+ # looks like the whisper models are in f16 by default
+ # so we need to convert the small tensors to f32 until we fully support f16 in ggml
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype = 1;
+ if use_f16:
+ if n_dims < 2 or \
+ name == "encoder.conv1.bias" or \
+ name == "encoder.conv2.bias" or \
+ name == "encoder.positional_embedding" or \
+ name == "decoder.positional_embedding":
+ ftype = 0
+ data = data.astype(np.float32)
+ print(" Converting to float32")
+ data = data.astype(np.float32)
+ ftype = 0
+ else:
+ data = data.astype(np.float32)
+ ftype = 0
+
+ #if name.startswith("encoder"):
+ # if name.endswith("mlp.0.weight") or \
+ # name.endswith("mlp.2.weight"):
+ # print(" Transposing")
+ # data = data.transpose()
+
+ # header
+ str = name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str);
+
+ # data
+ data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
--- /dev/null
+/*
+WAV audio loader and writer. Choice of public domain or MIT-0. See license statements at the end of this file.
+dr_wav - v0.12.16 - 2020-12-02
+
+David Reid - mackron@gmail.com
+
+GitHub: https://github.com/mackron/dr_libs
+*/
+
+/*
+RELEASE NOTES - VERSION 0.12
+============================
+Version 0.12 includes breaking changes to custom chunk handling.
+
+
+Changes to Chunk Callback
+-------------------------
+dr_wav supports the ability to fire a callback when a chunk is encounted (except for WAVE and FMT chunks). The callback has been updated to include both the
+container (RIFF or Wave64) and the FMT chunk which contains information about the format of the data in the wave file.
+
+Previously, there was no direct way to determine the container, and therefore no way to discriminate against the different IDs in the chunk header (RIFF and
+Wave64 containers encode chunk ID's differently). The `container` parameter can be used to know which ID to use.
+
+Sometimes it can be useful to know the data format at the time the chunk callback is fired. A pointer to a `drwav_fmt` object is now passed into the chunk
+callback which will give you information about the data format. To determine the sample format, use `drwav_fmt_get_format()`. This will return one of the
+`DR_WAVE_FORMAT_*` tokens.
+*/
+
+/*
+Introduction
+============
+This is a single file library. To use it, do something like the following in one .c file.
+
+ ```c
+ #define DR_WAV_IMPLEMENTATION
+ #include "dr_wav.h"
+ ```
+
+You can then #include this file in other parts of the program as you would with any other header file. Do something like the following to read audio data:
+
+ ```c
+ drwav wav;
+ if (!drwav_init_file(&wav, "my_song.wav", NULL)) {
+ // Error opening WAV file.
+ }
+
+ drwav_int32* pDecodedInterleavedPCMFrames = malloc(wav.totalPCMFrameCount * wav.channels * sizeof(drwav_int32));
+ size_t numberOfSamplesActuallyDecoded = drwav_read_pcm_frames_s32(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+
+ ...
+
+ drwav_uninit(&wav);
+ ```
+
+If you just want to quickly open and read the audio data in a single operation you can do something like this:
+
+ ```c
+ unsigned int channels;
+ unsigned int sampleRate;
+ drwav_uint64 totalPCMFrameCount;
+ float* pSampleData = drwav_open_file_and_read_pcm_frames_f32("my_song.wav", &channels, &sampleRate, &totalPCMFrameCount, NULL);
+ if (pSampleData == NULL) {
+ // Error opening and reading WAV file.
+ }
+
+ ...
+
+ drwav_free(pSampleData);
+ ```
+
+The examples above use versions of the API that convert the audio data to a consistent format (32-bit signed PCM, in this case), but you can still output the
+audio data in its internal format (see notes below for supported formats):
+
+ ```c
+ size_t framesRead = drwav_read_pcm_frames(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+ ```
+
+You can also read the raw bytes of audio data, which could be useful if dr_wav does not have native support for a particular data format:
+
+ ```c
+ size_t bytesRead = drwav_read_raw(&wav, bytesToRead, pRawDataBuffer);
+ ```
+
+dr_wav can also be used to output WAV files. This does not currently support compressed formats. To use this, look at `drwav_init_write()`,
+`drwav_init_file_write()`, etc. Use `drwav_write_pcm_frames()` to write samples, or `drwav_write_raw()` to write raw data in the "data" chunk.
+
+ ```c
+ drwav_data_format format;
+ format.container = drwav_container_riff; // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64.
+ format.format = DR_WAVE_FORMAT_PCM; // <-- Any of the DR_WAVE_FORMAT_* codes.
+ format.channels = 2;
+ format.sampleRate = 44100;
+ format.bitsPerSample = 16;
+ drwav_init_file_write(&wav, "data/recording.wav", &format, NULL);
+
+ ...
+
+ drwav_uint64 framesWritten = drwav_write_pcm_frames(pWav, frameCount, pSamples);
+ ```
+
+dr_wav has seamless support the Sony Wave64 format. The decoder will automatically detect it and it should Just Work without any manual intervention.
+
+
+Build Options
+=============
+#define these options before including this file.
+
+#define DR_WAV_NO_CONVERSION_API
+ Disables conversion APIs such as `drwav_read_pcm_frames_f32()` and `drwav_s16_to_f32()`.
+
+#define DR_WAV_NO_STDIO
+ Disables APIs that initialize a decoder from a file such as `drwav_init_file()`, `drwav_init_file_write()`, etc.
+
+
+
+Notes
+=====
+- Samples are always interleaved.
+- The default read function does not do any data conversion. Use `drwav_read_pcm_frames_f32()`, `drwav_read_pcm_frames_s32()` and `drwav_read_pcm_frames_s16()`
+ to read and convert audio data to 32-bit floating point, signed 32-bit integer and signed 16-bit integer samples respectively. Tested and supported internal
+ formats include the following:
+ - Unsigned 8-bit PCM
+ - Signed 12-bit PCM
+ - Signed 16-bit PCM
+ - Signed 24-bit PCM
+ - Signed 32-bit PCM
+ - IEEE 32-bit floating point
+ - IEEE 64-bit floating point
+ - A-law and u-law
+ - Microsoft ADPCM
+ - IMA ADPCM (DVI, format code 0x11)
+- dr_wav will try to read the WAV file as best it can, even if it's not strictly conformant to the WAV format.
+*/
+
+#ifndef dr_wav_h
+#define dr_wav_h
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DRWAV_STRINGIFY(x) #x
+#define DRWAV_XSTRINGIFY(x) DRWAV_STRINGIFY(x)
+
+#define DRWAV_VERSION_MAJOR 0
+#define DRWAV_VERSION_MINOR 12
+#define DRWAV_VERSION_REVISION 16
+#define DRWAV_VERSION_STRING DRWAV_XSTRINGIFY(DRWAV_VERSION_MAJOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_MINOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_REVISION)
+
+#include <stddef.h> /* For size_t. */
+
+/* Sized types. */
+typedef signed char drwav_int8;
+typedef unsigned char drwav_uint8;
+typedef signed short drwav_int16;
+typedef unsigned short drwav_uint16;
+typedef signed int drwav_int32;
+typedef unsigned int drwav_uint32;
+#if defined(_MSC_VER)
+ typedef signed __int64 drwav_int64;
+ typedef unsigned __int64 drwav_uint64;
+#else
+ #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)))
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wlong-long"
+ #if defined(__clang__)
+ #pragma GCC diagnostic ignored "-Wc++11-long-long"
+ #endif
+ #endif
+ typedef signed long long drwav_int64;
+ typedef unsigned long long drwav_uint64;
+ #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)))
+ #pragma GCC diagnostic pop
+ #endif
+#endif
+#if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__)) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__)
+ typedef drwav_uint64 drwav_uintptr;
+#else
+ typedef drwav_uint32 drwav_uintptr;
+#endif
+typedef drwav_uint8 drwav_bool8;
+typedef drwav_uint32 drwav_bool32;
+#define DRWAV_TRUE 1
+#define DRWAV_FALSE 0
+
+#if !defined(DRWAV_API)
+ #if defined(DRWAV_DLL)
+ #if defined(_WIN32)
+ #define DRWAV_DLL_IMPORT __declspec(dllimport)
+ #define DRWAV_DLL_EXPORT __declspec(dllexport)
+ #define DRWAV_DLL_PRIVATE static
+ #else
+ #if defined(__GNUC__) && __GNUC__ >= 4
+ #define DRWAV_DLL_IMPORT __attribute__((visibility("default")))
+ #define DRWAV_DLL_EXPORT __attribute__((visibility("default")))
+ #define DRWAV_DLL_PRIVATE __attribute__((visibility("hidden")))
+ #else
+ #define DRWAV_DLL_IMPORT
+ #define DRWAV_DLL_EXPORT
+ #define DRWAV_DLL_PRIVATE static
+ #endif
+ #endif
+
+ #if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION)
+ #define DRWAV_API DRWAV_DLL_EXPORT
+ #else
+ #define DRWAV_API DRWAV_DLL_IMPORT
+ #endif
+ #define DRWAV_PRIVATE DRWAV_DLL_PRIVATE
+ #else
+ #define DRWAV_API extern
+ #define DRWAV_PRIVATE static
+ #endif
+#endif
+
+typedef drwav_int32 drwav_result;
+#define DRWAV_SUCCESS 0
+#define DRWAV_ERROR -1 /* A generic error. */
+#define DRWAV_INVALID_ARGS -2
+#define DRWAV_INVALID_OPERATION -3
+#define DRWAV_OUT_OF_MEMORY -4
+#define DRWAV_OUT_OF_RANGE -5
+#define DRWAV_ACCESS_DENIED -6
+#define DRWAV_DOES_NOT_EXIST -7
+#define DRWAV_ALREADY_EXISTS -8
+#define DRWAV_TOO_MANY_OPEN_FILES -9
+#define DRWAV_INVALID_FILE -10
+#define DRWAV_TOO_BIG -11
+#define DRWAV_PATH_TOO_LONG -12
+#define DRWAV_NAME_TOO_LONG -13
+#define DRWAV_NOT_DIRECTORY -14
+#define DRWAV_IS_DIRECTORY -15
+#define DRWAV_DIRECTORY_NOT_EMPTY -16
+#define DRWAV_END_OF_FILE -17
+#define DRWAV_NO_SPACE -18
+#define DRWAV_BUSY -19
+#define DRWAV_IO_ERROR -20
+#define DRWAV_INTERRUPT -21
+#define DRWAV_UNAVAILABLE -22
+#define DRWAV_ALREADY_IN_USE -23
+#define DRWAV_BAD_ADDRESS -24
+#define DRWAV_BAD_SEEK -25
+#define DRWAV_BAD_PIPE -26
+#define DRWAV_DEADLOCK -27
+#define DRWAV_TOO_MANY_LINKS -28
+#define DRWAV_NOT_IMPLEMENTED -29
+#define DRWAV_NO_MESSAGE -30
+#define DRWAV_BAD_MESSAGE -31
+#define DRWAV_NO_DATA_AVAILABLE -32
+#define DRWAV_INVALID_DATA -33
+#define DRWAV_TIMEOUT -34
+#define DRWAV_NO_NETWORK -35
+#define DRWAV_NOT_UNIQUE -36
+#define DRWAV_NOT_SOCKET -37
+#define DRWAV_NO_ADDRESS -38
+#define DRWAV_BAD_PROTOCOL -39
+#define DRWAV_PROTOCOL_UNAVAILABLE -40
+#define DRWAV_PROTOCOL_NOT_SUPPORTED -41
+#define DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED -42
+#define DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED -43
+#define DRWAV_SOCKET_NOT_SUPPORTED -44
+#define DRWAV_CONNECTION_RESET -45
+#define DRWAV_ALREADY_CONNECTED -46
+#define DRWAV_NOT_CONNECTED -47
+#define DRWAV_CONNECTION_REFUSED -48
+#define DRWAV_NO_HOST -49
+#define DRWAV_IN_PROGRESS -50
+#define DRWAV_CANCELLED -51
+#define DRWAV_MEMORY_ALREADY_MAPPED -52
+#define DRWAV_AT_END -53
+
+/* Common data formats. */
+#define DR_WAVE_FORMAT_PCM 0x1
+#define DR_WAVE_FORMAT_ADPCM 0x2
+#define DR_WAVE_FORMAT_IEEE_FLOAT 0x3
+#define DR_WAVE_FORMAT_ALAW 0x6
+#define DR_WAVE_FORMAT_MULAW 0x7
+#define DR_WAVE_FORMAT_DVI_ADPCM 0x11
+#define DR_WAVE_FORMAT_EXTENSIBLE 0xFFFE
+
+/* Constants. */
+#ifndef DRWAV_MAX_SMPL_LOOPS
+#define DRWAV_MAX_SMPL_LOOPS 1
+#endif
+
+/* Flags to pass into drwav_init_ex(), etc. */
+#define DRWAV_SEQUENTIAL 0x00000001
+
+DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision);
+DRWAV_API const char* drwav_version_string(void);
+
+typedef enum
+{
+ drwav_seek_origin_start,
+ drwav_seek_origin_current
+} drwav_seek_origin;
+
+typedef enum
+{
+ drwav_container_riff,
+ drwav_container_w64,
+ drwav_container_rf64
+} drwav_container;
+
+typedef struct
+{
+ union
+ {
+ drwav_uint8 fourcc[4];
+ drwav_uint8 guid[16];
+ } id;
+
+ /* The size in bytes of the chunk. */
+ drwav_uint64 sizeInBytes;
+
+ /*
+ RIFF = 2 byte alignment.
+ W64 = 8 byte alignment.
+ */
+ unsigned int paddingSize;
+} drwav_chunk_header;
+
+typedef struct
+{
+ /*
+ The format tag exactly as specified in the wave file's "fmt" chunk. This can be used by applications
+ that require support for data formats not natively supported by dr_wav.
+ */
+ drwav_uint16 formatTag;
+
+ /* The number of channels making up the audio data. When this is set to 1 it is mono, 2 is stereo, etc. */
+ drwav_uint16 channels;
+
+ /* The sample rate. Usually set to something like 44100. */
+ drwav_uint32 sampleRate;
+
+ /* Average bytes per second. You probably don't need this, but it's left here for informational purposes. */
+ drwav_uint32 avgBytesPerSec;
+
+ /* Block align. This is equal to the number of channels * bytes per sample. */
+ drwav_uint16 blockAlign;
+
+ /* Bits per sample. */
+ drwav_uint16 bitsPerSample;
+
+ /* The size of the extended data. Only used internally for validation, but left here for informational purposes. */
+ drwav_uint16 extendedSize;
+
+ /*
+ The number of valid bits per sample. When <formatTag> is equal to WAVE_FORMAT_EXTENSIBLE, <bitsPerSample>
+ is always rounded up to the nearest multiple of 8. This variable contains information about exactly how
+ many bits are valid per sample. Mainly used for informational purposes.
+ */
+ drwav_uint16 validBitsPerSample;
+
+ /* The channel mask. Not used at the moment. */
+ drwav_uint32 channelMask;
+
+ /* The sub-format, exactly as specified by the wave file. */
+ drwav_uint8 subFormat[16];
+} drwav_fmt;
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT);
+
+
+/*
+Callback for when data is read. Return value is the number of bytes actually read.
+
+pUserData [in] The user data that was passed to drwav_init() and family.
+pBufferOut [out] The output buffer.
+bytesToRead [in] The number of bytes to read.
+
+Returns the number of bytes actually read.
+
+A return value of less than bytesToRead indicates the end of the stream. Do _not_ return from this callback until
+either the entire bytesToRead is filled or you have reached the end of the stream.
+*/
+typedef size_t (* drwav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead);
+
+/*
+Callback for when data is written. Returns value is the number of bytes actually written.
+
+pUserData [in] The user data that was passed to drwav_init_write() and family.
+pData [out] A pointer to the data to write.
+bytesToWrite [in] The number of bytes to write.
+
+Returns the number of bytes actually written.
+
+If the return value differs from bytesToWrite, it indicates an error.
+*/
+typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite);
+
+/*
+Callback for when data needs to be seeked.
+
+pUserData [in] The user data that was passed to drwav_init() and family.
+offset [in] The number of bytes to move, relative to the origin. Will never be negative.
+origin [in] The origin of the seek - the current position or the start of the stream.
+
+Returns whether or not the seek was successful.
+
+Whether or not it is relative to the beginning or current position is determined by the "origin" parameter which will be either drwav_seek_origin_start or
+drwav_seek_origin_current.
+*/
+typedef drwav_bool32 (* drwav_seek_proc)(void* pUserData, int offset, drwav_seek_origin origin);
+
+/*
+Callback for when drwav_init_ex() finds a chunk.
+
+pChunkUserData [in] The user data that was passed to the pChunkUserData parameter of drwav_init_ex() and family.
+onRead [in] A pointer to the function to call when reading.
+onSeek [in] A pointer to the function to call when seeking.
+pReadSeekUserData [in] The user data that was passed to the pReadSeekUserData parameter of drwav_init_ex() and family.
+pChunkHeader [in] A pointer to an object containing basic header information about the chunk. Use this to identify the chunk.
+container [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF.
+pFMT [in] A pointer to the object containing the contents of the "fmt" chunk.
+
+Returns the number of bytes read + seeked.
+
+To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must
+be the total number of bytes you have read _plus_ seeked.
+
+Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should
+use `id.fourcc`, otherwise you should use `id.guid`.
+
+The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the
+`DR_WAVE_FORMAT_*` identifiers.
+
+The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk.
+*/
+typedef drwav_uint64 (* drwav_chunk_proc)(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader, drwav_container container, const drwav_fmt* pFMT);
+
+typedef struct
+{
+ void* pUserData;
+ void* (* onMalloc)(size_t sz, void* pUserData);
+ void* (* onRealloc)(void* p, size_t sz, void* pUserData);
+ void (* onFree)(void* p, void* pUserData);
+} drwav_allocation_callbacks;
+
+/* Structure for internal use. Only used for loaders opened with drwav_init_memory(). */
+typedef struct
+{
+ const drwav_uint8* data;
+ size_t dataSize;
+ size_t currentReadPos;
+} drwav__memory_stream;
+
+/* Structure for internal use. Only used for writers opened with drwav_init_memory_write(). */
+typedef struct
+{
+ void** ppData;
+ size_t* pDataSize;
+ size_t dataSize;
+ size_t dataCapacity;
+ size_t currentWritePos;
+} drwav__memory_stream_write;
+
+typedef struct
+{
+ drwav_container container; /* RIFF, W64. */
+ drwav_uint32 format; /* DR_WAVE_FORMAT_* */
+ drwav_uint32 channels;
+ drwav_uint32 sampleRate;
+ drwav_uint32 bitsPerSample;
+} drwav_data_format;
+
+
+/* See the following for details on the 'smpl' chunk: https://sites.google.com/site/musicgapi/technical-documents/wav-file-format#smpl */
+typedef struct
+{
+ drwav_uint32 cuePointId;
+ drwav_uint32 type;
+ drwav_uint32 start;
+ drwav_uint32 end;
+ drwav_uint32 fraction;
+ drwav_uint32 playCount;
+} drwav_smpl_loop;
+
+ typedef struct
+{
+ drwav_uint32 manufacturer;
+ drwav_uint32 product;
+ drwav_uint32 samplePeriod;
+ drwav_uint32 midiUnityNotes;
+ drwav_uint32 midiPitchFraction;
+ drwav_uint32 smpteFormat;
+ drwav_uint32 smpteOffset;
+ drwav_uint32 numSampleLoops;
+ drwav_uint32 samplerData;
+ drwav_smpl_loop loops[DRWAV_MAX_SMPL_LOOPS];
+} drwav_smpl;
+
+typedef struct
+{
+ /* A pointer to the function to call when more data is needed. */
+ drwav_read_proc onRead;
+
+ /* A pointer to the function to call when data needs to be written. Only used when the drwav object is opened in write mode. */
+ drwav_write_proc onWrite;
+
+ /* A pointer to the function to call when the wav file needs to be seeked. */
+ drwav_seek_proc onSeek;
+
+ /* The user data to pass to callbacks. */
+ void* pUserData;
+
+ /* Allocation callbacks. */
+ drwav_allocation_callbacks allocationCallbacks;
+
+
+ /* Whether or not the WAV file is formatted as a standard RIFF file or W64. */
+ drwav_container container;
+
+
+ /* Structure containing format information exactly as specified by the wav file. */
+ drwav_fmt fmt;
+
+ /* The sample rate. Will be set to something like 44100. */
+ drwav_uint32 sampleRate;
+
+ /* The number of channels. This will be set to 1 for monaural streams, 2 for stereo, etc. */
+ drwav_uint16 channels;
+
+ /* The bits per sample. Will be set to something like 16, 24, etc. */
+ drwav_uint16 bitsPerSample;
+
+ /* Equal to fmt.formatTag, or the value specified by fmt.subFormat if fmt.formatTag is equal to 65534 (WAVE_FORMAT_EXTENSIBLE). */
+ drwav_uint16 translatedFormatTag;
+
+ /* The total number of PCM frames making up the audio data. */
+ drwav_uint64 totalPCMFrameCount;
+
+
+ /* The size in bytes of the data chunk. */
+ drwav_uint64 dataChunkDataSize;
+
+ /* The position in the stream of the first byte of the data chunk. This is used for seeking. */
+ drwav_uint64 dataChunkDataPos;
+
+ /* The number of bytes remaining in the data chunk. */
+ drwav_uint64 bytesRemaining;
+
+
+ /*
+ Only used in sequential write mode. Keeps track of the desired size of the "data" chunk at the point of initialization time. Always
+ set to 0 for non-sequential writes and when the drwav object is opened in read mode. Used for validation.
+ */
+ drwav_uint64 dataChunkDataSizeTargetWrite;
+
+ /* Keeps track of whether or not the wav writer was initialized in sequential mode. */
+ drwav_bool32 isSequentialWrite;
+
+
+ /* smpl chunk. */
+ drwav_smpl smpl;
+
+
+ /* A hack to avoid a DRWAV_MALLOC() when opening a decoder with drwav_init_memory(). */
+ drwav__memory_stream memoryStream;
+ drwav__memory_stream_write memoryStreamWrite;
+
+ /* Generic data for compressed formats. This data is shared across all block-compressed formats. */
+ struct
+ {
+ drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */
+ } compressed;
+
+ /* Microsoft ADPCM specific data. */
+ struct
+ {
+ drwav_uint32 bytesRemainingInBlock;
+ drwav_uint16 predictor[2];
+ drwav_int32 delta[2];
+ drwav_int32 cachedFrames[4]; /* Samples are stored in this cache during decoding. */
+ drwav_uint32 cachedFrameCount;
+ drwav_int32 prevFrames[2][2]; /* The previous 2 samples for each channel (2 channels at most). */
+ } msadpcm;
+
+ /* IMA ADPCM specific data. */
+ struct
+ {
+ drwav_uint32 bytesRemainingInBlock;
+ drwav_int32 predictor[2];
+ drwav_int32 stepIndex[2];
+ drwav_int32 cachedFrames[16]; /* Samples are stored in this cache during decoding. */
+ drwav_uint32 cachedFrameCount;
+ } ima;
+} drwav;
+
+
+/*
+Initializes a pre-allocated drwav object for reading.
+
+pWav [out] A pointer to the drwav object being initialized.
+onRead [in] The function to call when data needs to be read from the client.
+onSeek [in] The function to call when the read position of the client data needs to move.
+onChunk [in, optional] The function to call when a chunk is enumerated at initialized time.
+pUserData, pReadSeekUserData [in, optional] A pointer to application defined data that will be passed to onRead and onSeek.
+pChunkUserData [in, optional] A pointer to application defined data that will be passed to onChunk.
+flags [in, optional] A set of flags for controlling how things are loaded.
+
+Returns true if successful; false otherwise.
+
+Close the loader with drwav_uninit().
+
+This is the lowest level function for initializing a WAV file. You can also use drwav_init_file() and drwav_init_memory()
+to open the stream from a file or from a block of memory respectively.
+
+Possible values for flags:
+ DRWAV_SEQUENTIAL: Never perform a backwards seek while loading. This disables the chunk callback and will cause this function
+ to return as soon as the data chunk is found. Any chunks after the data chunk will be ignored.
+
+drwav_init() is equivalent to "drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0);".
+
+The onChunk callback is not called for the WAVE or FMT chunks. The contents of the FMT chunk can be read from pWav->fmt
+after the function returns.
+
+See also: drwav_init_file(), drwav_init_memory(), drwav_uninit()
+*/
+DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Initializes a pre-allocated drwav object for writing.
+
+onWrite [in] The function to call when data needs to be written.
+onSeek [in] The function to call when the write position needs to move.
+pUserData [in, optional] A pointer to application defined data that will be passed to onWrite and onSeek.
+
+Returns true if successful; false otherwise.
+
+Close the writer with drwav_uninit().
+
+This is the lowest level function for initializing a WAV file. You can also use drwav_init_file_write() and drwav_init_memory_write()
+to open the stream from a file or from a block of memory respectively.
+
+If the total sample count is known, you can use drwav_init_write_sequential(). This avoids the need for dr_wav to perform
+a post-processing step for storing the total sample count and the size of the data chunk which requires a backwards seek.
+
+See also: drwav_init_file_write(), drwav_init_memory_write(), drwav_uninit()
+*/
+DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Utility function to determine the target size of the entire data to be written (including all headers and chunks).
+
+Returns the target size in bytes.
+
+Useful if the application needs to know the size to allocate.
+
+Only writing to the RIFF chunk and one data chunk is currently supported.
+
+See also: drwav_init_write(), drwav_init_file_write(), drwav_init_memory_write()
+*/
+DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount);
+
+/*
+Uninitializes the given drwav object.
+
+Use this only for objects initialized with drwav_init*() functions (drwav_init(), drwav_init_ex(), drwav_init_write(), drwav_init_write_sequential()).
+*/
+DRWAV_API drwav_result drwav_uninit(drwav* pWav);
+
+
+/*
+Reads raw audio data.
+
+This is the lowest level function for reading audio data. It simply reads the given number of
+bytes of the raw internal sample data.
+
+Consider using drwav_read_pcm_frames_s16(), drwav_read_pcm_frames_s32() or drwav_read_pcm_frames_f32() for
+reading sample data in a consistent format.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of bytes actually read.
+*/
+DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut);
+
+/*
+Reads up to the specified number of PCM frames from the WAV file.
+
+The output data will be in the file's internal format, converted to native-endian byte order. Use
+drwav_read_pcm_frames_s16/f32/s32() to read data in a specific format.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached or
+you have requested more PCM frames than can possibly fit in the output buffer.
+
+This function will only work when sample data is of a fixed size and uncompressed. If you are
+using a compressed format consider using drwav_read_raw() or drwav_read_pcm_frames_s16/s32/f32().
+
+pBufferOut can be NULL in which case a seek will be performed.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+
+/*
+Seeks to the given PCM frame.
+
+Returns true if successful; false otherwise.
+*/
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex);
+
+
+/*
+Writes raw audio data.
+
+Returns the number of bytes actually written. If this differs from bytesToWrite, it indicates an error.
+*/
+DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData);
+
+/*
+Writes PCM frames.
+
+Returns the number of PCM frames written.
+
+Input samples need to be in native-endian byte order. On big-endian architectures the input data will be converted to
+little-endian. Use drwav_write_raw() to write raw audio data without performing any conversion.
+*/
+DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+
+
+/* Conversion Utilities */
+#ifndef DR_WAV_NO_CONVERSION_API
+
+/*
+Reads a chunk of audio data and converts it to signed 16-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 32-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to IEEE 32-bit floating point samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 16-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to signed 32-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 16-bit PCM samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 32-bit floating point samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+#endif /* DR_WAV_NO_CONVERSION_API */
+
+
+/* High-Level Convenience Helpers */
+
+#ifndef DR_WAV_NO_STDIO
+/*
+Helper for initializing a wave file for reading using stdio.
+
+This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav
+objects because the operating system may restrict the number of file handles an application can have open at
+any given time.
+*/
+DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Helper for initializing a wave file for writing using stdio.
+
+This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav
+objects because the operating system may restrict the number of file handles an application can have open at
+any given time.
+*/
+DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+#endif /* DR_WAV_NO_STDIO */
+
+/*
+Helper for initializing a loader from a pre-allocated memory buffer.
+
+This does not create a copy of the data. It is up to the application to ensure the buffer remains valid for
+the lifetime of the drwav object.
+
+The buffer should contain the contents of the entire wave file, not just the sample data.
+*/
+DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Helper for initializing a writer which outputs data to a memory buffer.
+
+dr_wav will manage the memory allocations, however it is up to the caller to free the data with drwav_free().
+
+The buffer will remain allocated even after drwav_uninit() is called. The buffer should not be considered valid
+until after drwav_uninit() has been called.
+*/
+DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+
+#ifndef DR_WAV_NO_CONVERSION_API
+/*
+Opens and reads an entire wav file in a single operation.
+
+The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer.
+*/
+DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+#ifndef DR_WAV_NO_STDIO
+/*
+Opens and decodes an entire wav file in a single operation.
+
+The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer.
+*/
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+#endif
+/*
+Opens and decodes an entire wav file from a block of memory in a single operation.
+
+The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer.
+*/
+DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+#endif
+
+/* Frees data that was allocated internally by dr_wav. */
+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/* Converts bytes from a wav stream to a sized type of native endian. */
+DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data);
+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data);
+DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data);
+DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data);
+DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data);
+DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data);
+
+/* Compares a GUID for the purpose of checking the type of a Wave64 chunk. */
+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]);
+
+/* Compares a four-character-code for the purpose of checking the type of a RIFF chunk. */
+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b);
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* dr_wav_h */
+
+
+/************************************************************************************************************************************************************
+ ************************************************************************************************************************************************************
+
+ IMPLEMENTATION
+
+ ************************************************************************************************************************************************************
+ ************************************************************************************************************************************************************/
+#if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION)
+#ifndef dr_wav_c
+#define dr_wav_c
+
+#include <stdlib.h>
+#include <string.h> /* For memcpy(), memset() */
+#include <limits.h> /* For INT_MAX */
+
+#ifndef DR_WAV_NO_STDIO
+#include <stdio.h>
+#include <wchar.h>
+#endif
+
+/* Standard library stuff. */
+#ifndef DRWAV_ASSERT
+#include <assert.h>
+#define DRWAV_ASSERT(expression) assert(expression)
+#endif
+#ifndef DRWAV_MALLOC
+#define DRWAV_MALLOC(sz) malloc((sz))
+#endif
+#ifndef DRWAV_REALLOC
+#define DRWAV_REALLOC(p, sz) realloc((p), (sz))
+#endif
+#ifndef DRWAV_FREE
+#define DRWAV_FREE(p) free((p))
+#endif
+#ifndef DRWAV_COPY_MEMORY
+#define DRWAV_COPY_MEMORY(dst, src, sz) memcpy((dst), (src), (sz))
+#endif
+#ifndef DRWAV_ZERO_MEMORY
+#define DRWAV_ZERO_MEMORY(p, sz) memset((p), 0, (sz))
+#endif
+#ifndef DRWAV_ZERO_OBJECT
+#define DRWAV_ZERO_OBJECT(p) DRWAV_ZERO_MEMORY((p), sizeof(*p))
+#endif
+
+#define drwav_countof(x) (sizeof(x) / sizeof(x[0]))
+#define drwav_align(x, a) ((((x) + (a) - 1) / (a)) * (a))
+#define drwav_min(a, b) (((a) < (b)) ? (a) : (b))
+#define drwav_max(a, b) (((a) > (b)) ? (a) : (b))
+#define drwav_clamp(x, lo, hi) (drwav_max((lo), drwav_min((hi), (x))))
+
+#define DRWAV_MAX_SIMD_VECTOR_SIZE 64 /* 64 for AVX-512 in the future. */
+
+/* CPU architecture. */
+#if defined(__x86_64__) || defined(_M_X64)
+ #define DRWAV_X64
+#elif defined(__i386) || defined(_M_IX86)
+ #define DRWAV_X86
+#elif defined(__arm__) || defined(_M_ARM)
+ #define DRWAV_ARM
+#endif
+
+#ifdef _MSC_VER
+ #define DRWAV_INLINE __forceinline
+#elif defined(__GNUC__)
+ /*
+ I've had a bug report where GCC is emitting warnings about functions possibly not being inlineable. This warning happens when
+ the __attribute__((always_inline)) attribute is defined without an "inline" statement. I think therefore there must be some
+ case where "__inline__" is not always defined, thus the compiler emitting these warnings. When using -std=c89 or -ansi on the
+ command line, we cannot use the "inline" keyword and instead need to use "__inline__". In an attempt to work around this issue
+ I am using "__inline__" only when we're compiling in strict ANSI mode.
+ */
+ #if defined(__STRICT_ANSI__)
+ #define DRWAV_INLINE __inline__ __attribute__((always_inline))
+ #else
+ #define DRWAV_INLINE inline __attribute__((always_inline))
+ #endif
+#elif defined(__WATCOMC__)
+ #define DRWAV_INLINE __inline
+#else
+ #define DRWAV_INLINE
+#endif
+
+#if defined(SIZE_MAX)
+ #define DRWAV_SIZE_MAX SIZE_MAX
+#else
+ #if defined(_WIN64) || defined(_LP64) || defined(__LP64__)
+ #define DRWAV_SIZE_MAX ((drwav_uint64)0xFFFFFFFFFFFFFFFF)
+ #else
+ #define DRWAV_SIZE_MAX 0xFFFFFFFF
+ #endif
+#endif
+
+#if defined(_MSC_VER) && _MSC_VER >= 1400
+ #define DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #define DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #define DRWAV_HAS_BYTESWAP64_INTRINSIC
+#elif defined(__clang__)
+ #if defined(__has_builtin)
+ #if __has_builtin(__builtin_bswap16)
+ #define DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #endif
+ #if __has_builtin(__builtin_bswap32)
+ #define DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #endif
+ #if __has_builtin(__builtin_bswap64)
+ #define DRWAV_HAS_BYTESWAP64_INTRINSIC
+ #endif
+ #endif
+#elif defined(__GNUC__)
+ #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3))
+ #define DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #define DRWAV_HAS_BYTESWAP64_INTRINSIC
+ #endif
+ #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))
+ #define DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #endif
+#endif
+
+DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision)
+{
+ if (pMajor) {
+ *pMajor = DRWAV_VERSION_MAJOR;
+ }
+
+ if (pMinor) {
+ *pMinor = DRWAV_VERSION_MINOR;
+ }
+
+ if (pRevision) {
+ *pRevision = DRWAV_VERSION_REVISION;
+ }
+}
+
+DRWAV_API const char* drwav_version_string(void)
+{
+ return DRWAV_VERSION_STRING;
+}
+
+/*
+These limits are used for basic validation when initializing the decoder. If you exceed these limits, first of all: what on Earth are
+you doing?! (Let me know, I'd be curious!) Second, you can adjust these by #define-ing them before the dr_wav implementation.
+*/
+#ifndef DRWAV_MAX_SAMPLE_RATE
+#define DRWAV_MAX_SAMPLE_RATE 384000
+#endif
+#ifndef DRWAV_MAX_CHANNELS
+#define DRWAV_MAX_CHANNELS 256
+#endif
+#ifndef DRWAV_MAX_BITS_PER_SAMPLE
+#define DRWAV_MAX_BITS_PER_SAMPLE 64
+#endif
+
+static const drwav_uint8 drwavGUID_W64_RIFF[16] = {0x72,0x69,0x66,0x66, 0x2E,0x91, 0xCF,0x11, 0xA5,0xD6, 0x28,0xDB,0x04,0xC1,0x00,0x00}; /* 66666972-912E-11CF-A5D6-28DB04C10000 */
+static const drwav_uint8 drwavGUID_W64_WAVE[16] = {0x77,0x61,0x76,0x65, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 65766177-ACF3-11D3-8CD1-00C04F8EDB8A */
+/*static const drwav_uint8 drwavGUID_W64_JUNK[16] = {0x6A,0x75,0x6E,0x6B, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A};*/ /* 6B6E756A-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_FMT [16] = {0x66,0x6D,0x74,0x20, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 20746D66-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_FACT[16] = {0x66,0x61,0x63,0x74, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 74636166-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_DATA[16] = {0x64,0x61,0x74,0x61, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 61746164-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_SMPL[16] = {0x73,0x6D,0x70,0x6C, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 6C706D73-ACF3-11D3-8CD1-00C04F8EDB8A */
+
+static DRWAV_INLINE drwav_bool32 drwav__guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16])
+{
+ int i;
+ for (i = 0; i < 16; i += 1) {
+ if (a[i] != b[i]) {
+ return DRWAV_FALSE;
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+static DRWAV_INLINE drwav_bool32 drwav__fourcc_equal(const drwav_uint8* a, const char* b)
+{
+ return
+ a[0] == b[0] &&
+ a[1] == b[1] &&
+ a[2] == b[2] &&
+ a[3] == b[3];
+}
+
+
+
+static DRWAV_INLINE int drwav__is_little_endian(void)
+{
+#if defined(DRWAV_X86) || defined(DRWAV_X64)
+ return DRWAV_TRUE;
+#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && __BYTE_ORDER == __LITTLE_ENDIAN
+ return DRWAV_TRUE;
+#else
+ int n = 1;
+ return (*(char*)&n) == 1;
+#endif
+}
+
+static DRWAV_INLINE drwav_uint16 drwav__bytes_to_u16(const drwav_uint8* data)
+{
+ return (data[0] << 0) | (data[1] << 8);
+}
+
+static DRWAV_INLINE drwav_int16 drwav__bytes_to_s16(const drwav_uint8* data)
+{
+ return (short)drwav__bytes_to_u16(data);
+}
+
+static DRWAV_INLINE drwav_uint32 drwav__bytes_to_u32(const drwav_uint8* data)
+{
+ return (data[0] << 0) | (data[1] << 8) | (data[2] << 16) | (data[3] << 24);
+}
+
+static DRWAV_INLINE drwav_int32 drwav__bytes_to_s32(const drwav_uint8* data)
+{
+ return (drwav_int32)drwav__bytes_to_u32(data);
+}
+
+static DRWAV_INLINE drwav_uint64 drwav__bytes_to_u64(const drwav_uint8* data)
+{
+ return
+ ((drwav_uint64)data[0] << 0) | ((drwav_uint64)data[1] << 8) | ((drwav_uint64)data[2] << 16) | ((drwav_uint64)data[3] << 24) |
+ ((drwav_uint64)data[4] << 32) | ((drwav_uint64)data[5] << 40) | ((drwav_uint64)data[6] << 48) | ((drwav_uint64)data[7] << 56);
+}
+
+static DRWAV_INLINE drwav_int64 drwav__bytes_to_s64(const drwav_uint8* data)
+{
+ return (drwav_int64)drwav__bytes_to_u64(data);
+}
+
+static DRWAV_INLINE void drwav__bytes_to_guid(const drwav_uint8* data, drwav_uint8* guid)
+{
+ int i;
+ for (i = 0; i < 16; ++i) {
+ guid[i] = data[i];
+ }
+}
+
+
+static DRWAV_INLINE drwav_uint16 drwav__bswap16(drwav_uint16 n)
+{
+#ifdef DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #if defined(_MSC_VER)
+ return _byteswap_ushort(n);
+ #elif defined(__GNUC__) || defined(__clang__)
+ return __builtin_bswap16(n);
+ #else
+ #error "This compiler does not support the byte swap intrinsic."
+ #endif
+#else
+ return ((n & 0xFF00) >> 8) |
+ ((n & 0x00FF) << 8);
+#endif
+}
+
+static DRWAV_INLINE drwav_uint32 drwav__bswap32(drwav_uint32 n)
+{
+#ifdef DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #if defined(_MSC_VER)
+ return _byteswap_ulong(n);
+ #elif defined(__GNUC__) || defined(__clang__)
+ #if defined(DRWAV_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 6) && !defined(DRWAV_64BIT) /* <-- 64-bit inline assembly has not been tested, so disabling for now. */
+ /* Inline assembly optimized implementation for ARM. In my testing, GCC does not generate optimized code with __builtin_bswap32(). */
+ drwav_uint32 r;
+ __asm__ __volatile__ (
+ #if defined(DRWAV_64BIT)
+ "rev %w[out], %w[in]" : [out]"=r"(r) : [in]"r"(n) /* <-- This is untested. If someone in the community could test this, that would be appreciated! */
+ #else
+ "rev %[out], %[in]" : [out]"=r"(r) : [in]"r"(n)
+ #endif
+ );
+ return r;
+ #else
+ return __builtin_bswap32(n);
+ #endif
+ #else
+ #error "This compiler does not support the byte swap intrinsic."
+ #endif
+#else
+ return ((n & 0xFF000000) >> 24) |
+ ((n & 0x00FF0000) >> 8) |
+ ((n & 0x0000FF00) << 8) |
+ ((n & 0x000000FF) << 24);
+#endif
+}
+
+static DRWAV_INLINE drwav_uint64 drwav__bswap64(drwav_uint64 n)
+{
+#ifdef DRWAV_HAS_BYTESWAP64_INTRINSIC
+ #if defined(_MSC_VER)
+ return _byteswap_uint64(n);
+ #elif defined(__GNUC__) || defined(__clang__)
+ return __builtin_bswap64(n);
+ #else
+ #error "This compiler does not support the byte swap intrinsic."
+ #endif
+#else
+ /* Weird "<< 32" bitshift is required for C89 because it doesn't support 64-bit constants. Should be optimized out by a good compiler. */
+ return ((n & ((drwav_uint64)0xFF000000 << 32)) >> 56) |
+ ((n & ((drwav_uint64)0x00FF0000 << 32)) >> 40) |
+ ((n & ((drwav_uint64)0x0000FF00 << 32)) >> 24) |
+ ((n & ((drwav_uint64)0x000000FF << 32)) >> 8) |
+ ((n & ((drwav_uint64)0xFF000000 )) << 8) |
+ ((n & ((drwav_uint64)0x00FF0000 )) << 24) |
+ ((n & ((drwav_uint64)0x0000FF00 )) << 40) |
+ ((n & ((drwav_uint64)0x000000FF )) << 56);
+#endif
+}
+
+
+static DRWAV_INLINE drwav_int16 drwav__bswap_s16(drwav_int16 n)
+{
+ return (drwav_int16)drwav__bswap16((drwav_uint16)n);
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_s16(drwav_int16* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_s16(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE void drwav__bswap_s24(drwav_uint8* p)
+{
+ drwav_uint8 t;
+ t = p[0];
+ p[0] = p[2];
+ p[2] = t;
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_s24(drwav_uint8* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ drwav_uint8* pSample = pSamples + (iSample*3);
+ drwav__bswap_s24(pSample);
+ }
+}
+
+
+static DRWAV_INLINE drwav_int32 drwav__bswap_s32(drwav_int32 n)
+{
+ return (drwav_int32)drwav__bswap32((drwav_uint32)n);
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_s32(drwav_int32* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_s32(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE float drwav__bswap_f32(float n)
+{
+ union {
+ drwav_uint32 i;
+ float f;
+ } x;
+ x.f = n;
+ x.i = drwav__bswap32(x.i);
+
+ return x.f;
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_f32(float* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_f32(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE double drwav__bswap_f64(double n)
+{
+ union {
+ drwav_uint64 i;
+ double f;
+ } x;
+ x.f = n;
+ x.i = drwav__bswap64(x.i);
+
+ return x.f;
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_f64(double* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_f64(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE void drwav__bswap_samples_pcm(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample)
+{
+ /* Assumes integer PCM. Floating point PCM is done in drwav__bswap_samples_ieee(). */
+ switch (bytesPerSample)
+ {
+ case 2: /* s16, s12 (loosely packed) */
+ {
+ drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount);
+ } break;
+ case 3: /* s24 */
+ {
+ drwav__bswap_samples_s24((drwav_uint8*)pSamples, sampleCount);
+ } break;
+ case 4: /* s32 */
+ {
+ drwav__bswap_samples_s32((drwav_int32*)pSamples, sampleCount);
+ } break;
+ default:
+ {
+ /* Unsupported format. */
+ DRWAV_ASSERT(DRWAV_FALSE);
+ } break;
+ }
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_ieee(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample)
+{
+ switch (bytesPerSample)
+ {
+ #if 0 /* Contributions welcome for f16 support. */
+ case 2: /* f16 */
+ {
+ drwav__bswap_samples_f16((drwav_float16*)pSamples, sampleCount);
+ } break;
+ #endif
+ case 4: /* f32 */
+ {
+ drwav__bswap_samples_f32((float*)pSamples, sampleCount);
+ } break;
+ case 8: /* f64 */
+ {
+ drwav__bswap_samples_f64((double*)pSamples, sampleCount);
+ } break;
+ default:
+ {
+ /* Unsupported format. */
+ DRWAV_ASSERT(DRWAV_FALSE);
+ } break;
+ }
+}
+
+static DRWAV_INLINE void drwav__bswap_samples(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample, drwav_uint16 format)
+{
+ switch (format)
+ {
+ case DR_WAVE_FORMAT_PCM:
+ {
+ drwav__bswap_samples_pcm(pSamples, sampleCount, bytesPerSample);
+ } break;
+
+ case DR_WAVE_FORMAT_IEEE_FLOAT:
+ {
+ drwav__bswap_samples_ieee(pSamples, sampleCount, bytesPerSample);
+ } break;
+
+ case DR_WAVE_FORMAT_ALAW:
+ case DR_WAVE_FORMAT_MULAW:
+ {
+ drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount);
+ } break;
+
+ case DR_WAVE_FORMAT_ADPCM:
+ case DR_WAVE_FORMAT_DVI_ADPCM:
+ default:
+ {
+ /* Unsupported format. */
+ DRWAV_ASSERT(DRWAV_FALSE);
+ } break;
+ }
+}
+
+
+static void* drwav__malloc_default(size_t sz, void* pUserData)
+{
+ (void)pUserData;
+ return DRWAV_MALLOC(sz);
+}
+
+static void* drwav__realloc_default(void* p, size_t sz, void* pUserData)
+{
+ (void)pUserData;
+ return DRWAV_REALLOC(p, sz);
+}
+
+static void drwav__free_default(void* p, void* pUserData)
+{
+ (void)pUserData;
+ DRWAV_FREE(p);
+}
+
+
+static void* drwav__malloc_from_callbacks(size_t sz, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks == NULL) {
+ return NULL;
+ }
+
+ if (pAllocationCallbacks->onMalloc != NULL) {
+ return pAllocationCallbacks->onMalloc(sz, pAllocationCallbacks->pUserData);
+ }
+
+ /* Try using realloc(). */
+ if (pAllocationCallbacks->onRealloc != NULL) {
+ return pAllocationCallbacks->onRealloc(NULL, sz, pAllocationCallbacks->pUserData);
+ }
+
+ return NULL;
+}
+
+static void* drwav__realloc_from_callbacks(void* p, size_t szNew, size_t szOld, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks == NULL) {
+ return NULL;
+ }
+
+ if (pAllocationCallbacks->onRealloc != NULL) {
+ return pAllocationCallbacks->onRealloc(p, szNew, pAllocationCallbacks->pUserData);
+ }
+
+ /* Try emulating realloc() in terms of malloc()/free(). */
+ if (pAllocationCallbacks->onMalloc != NULL && pAllocationCallbacks->onFree != NULL) {
+ void* p2;
+
+ p2 = pAllocationCallbacks->onMalloc(szNew, pAllocationCallbacks->pUserData);
+ if (p2 == NULL) {
+ return NULL;
+ }
+
+ if (p != NULL) {
+ DRWAV_COPY_MEMORY(p2, p, szOld);
+ pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData);
+ }
+
+ return p2;
+ }
+
+ return NULL;
+}
+
+static void drwav__free_from_callbacks(void* p, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (p == NULL || pAllocationCallbacks == NULL) {
+ return;
+ }
+
+ if (pAllocationCallbacks->onFree != NULL) {
+ pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData);
+ }
+}
+
+
+static drwav_allocation_callbacks drwav_copy_allocation_callbacks_or_defaults(const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks != NULL) {
+ /* Copy. */
+ return *pAllocationCallbacks;
+ } else {
+ /* Defaults. */
+ drwav_allocation_callbacks allocationCallbacks;
+ allocationCallbacks.pUserData = NULL;
+ allocationCallbacks.onMalloc = drwav__malloc_default;
+ allocationCallbacks.onRealloc = drwav__realloc_default;
+ allocationCallbacks.onFree = drwav__free_default;
+ return allocationCallbacks;
+ }
+}
+
+
+static DRWAV_INLINE drwav_bool32 drwav__is_compressed_format_tag(drwav_uint16 formatTag)
+{
+ return
+ formatTag == DR_WAVE_FORMAT_ADPCM ||
+ formatTag == DR_WAVE_FORMAT_DVI_ADPCM;
+}
+
+static unsigned int drwav__chunk_padding_size_riff(drwav_uint64 chunkSize)
+{
+ return (unsigned int)(chunkSize % 2);
+}
+
+static unsigned int drwav__chunk_padding_size_w64(drwav_uint64 chunkSize)
+{
+ return (unsigned int)(chunkSize % 8);
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut);
+static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut);
+static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount);
+
+static drwav_result drwav__read_chunk_header(drwav_read_proc onRead, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_chunk_header* pHeaderOut)
+{
+ if (container == drwav_container_riff || container == drwav_container_rf64) {
+ drwav_uint8 sizeInBytes[4];
+
+ if (onRead(pUserData, pHeaderOut->id.fourcc, 4) != 4) {
+ return DRWAV_AT_END;
+ }
+
+ if (onRead(pUserData, sizeInBytes, 4) != 4) {
+ return DRWAV_INVALID_FILE;
+ }
+
+ pHeaderOut->sizeInBytes = drwav__bytes_to_u32(sizeInBytes);
+ pHeaderOut->paddingSize = drwav__chunk_padding_size_riff(pHeaderOut->sizeInBytes);
+ *pRunningBytesReadOut += 8;
+ } else {
+ drwav_uint8 sizeInBytes[8];
+
+ if (onRead(pUserData, pHeaderOut->id.guid, 16) != 16) {
+ return DRWAV_AT_END;
+ }
+
+ if (onRead(pUserData, sizeInBytes, 8) != 8) {
+ return DRWAV_INVALID_FILE;
+ }
+
+ pHeaderOut->sizeInBytes = drwav__bytes_to_u64(sizeInBytes) - 24; /* <-- Subtract 24 because w64 includes the size of the header. */
+ pHeaderOut->paddingSize = drwav__chunk_padding_size_w64(pHeaderOut->sizeInBytes);
+ *pRunningBytesReadOut += 24;
+ }
+
+ return DRWAV_SUCCESS;
+}
+
+static drwav_bool32 drwav__seek_forward(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData)
+{
+ drwav_uint64 bytesRemainingToSeek = offset;
+ while (bytesRemainingToSeek > 0) {
+ if (bytesRemainingToSeek > 0x7FFFFFFF) {
+ if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingToSeek -= 0x7FFFFFFF;
+ } else {
+ if (!onSeek(pUserData, (int)bytesRemainingToSeek, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingToSeek = 0;
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav__seek_from_start(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData)
+{
+ if (offset <= 0x7FFFFFFF) {
+ return onSeek(pUserData, (int)offset, drwav_seek_origin_start);
+ }
+
+ /* Larger than 32-bit seek. */
+ if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_start)) {
+ return DRWAV_FALSE;
+ }
+ offset -= 0x7FFFFFFF;
+
+ for (;;) {
+ if (offset <= 0x7FFFFFFF) {
+ return onSeek(pUserData, (int)offset, drwav_seek_origin_current);
+ }
+
+ if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ offset -= 0x7FFFFFFF;
+ }
+
+ /* Should never get here. */
+ /*return DRWAV_TRUE; */
+}
+
+
+static drwav_bool32 drwav__read_fmt(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_fmt* fmtOut)
+{
+ drwav_chunk_header header;
+ drwav_uint8 fmt[16];
+
+ if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+
+ /* Skip non-fmt chunks. */
+ while (((container == drwav_container_riff || container == drwav_container_rf64) && !drwav__fourcc_equal(header.id.fourcc, "fmt ")) || (container == drwav_container_w64 && !drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT))) {
+ if (!drwav__seek_forward(onSeek, header.sizeInBytes + header.paddingSize, pUserData)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += header.sizeInBytes + header.paddingSize;
+
+ /* Try the next header. */
+ if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+ }
+
+
+ /* Validation. */
+ if (container == drwav_container_riff || container == drwav_container_rf64) {
+ if (!drwav__fourcc_equal(header.id.fourcc, "fmt ")) {
+ return DRWAV_FALSE;
+ }
+ } else {
+ if (!drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT)) {
+ return DRWAV_FALSE;
+ }
+ }
+
+
+ if (onRead(pUserData, fmt, sizeof(fmt)) != sizeof(fmt)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += sizeof(fmt);
+
+ fmtOut->formatTag = drwav__bytes_to_u16(fmt + 0);
+ fmtOut->channels = drwav__bytes_to_u16(fmt + 2);
+ fmtOut->sampleRate = drwav__bytes_to_u32(fmt + 4);
+ fmtOut->avgBytesPerSec = drwav__bytes_to_u32(fmt + 8);
+ fmtOut->blockAlign = drwav__bytes_to_u16(fmt + 12);
+ fmtOut->bitsPerSample = drwav__bytes_to_u16(fmt + 14);
+
+ fmtOut->extendedSize = 0;
+ fmtOut->validBitsPerSample = 0;
+ fmtOut->channelMask = 0;
+ memset(fmtOut->subFormat, 0, sizeof(fmtOut->subFormat));
+
+ if (header.sizeInBytes > 16) {
+ drwav_uint8 fmt_cbSize[2];
+ int bytesReadSoFar = 0;
+
+ if (onRead(pUserData, fmt_cbSize, sizeof(fmt_cbSize)) != sizeof(fmt_cbSize)) {
+ return DRWAV_FALSE; /* Expecting more data. */
+ }
+ *pRunningBytesReadOut += sizeof(fmt_cbSize);
+
+ bytesReadSoFar = 18;
+
+ fmtOut->extendedSize = drwav__bytes_to_u16(fmt_cbSize);
+ if (fmtOut->extendedSize > 0) {
+ /* Simple validation. */
+ if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+ if (fmtOut->extendedSize != 22) {
+ return DRWAV_FALSE;
+ }
+ }
+
+ if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+ drwav_uint8 fmtext[22];
+ if (onRead(pUserData, fmtext, fmtOut->extendedSize) != fmtOut->extendedSize) {
+ return DRWAV_FALSE; /* Expecting more data. */
+ }
+
+ fmtOut->validBitsPerSample = drwav__bytes_to_u16(fmtext + 0);
+ fmtOut->channelMask = drwav__bytes_to_u32(fmtext + 2);
+ drwav__bytes_to_guid(fmtext + 6, fmtOut->subFormat);
+ } else {
+ if (!onSeek(pUserData, fmtOut->extendedSize, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ }
+ *pRunningBytesReadOut += fmtOut->extendedSize;
+
+ bytesReadSoFar += fmtOut->extendedSize;
+ }
+
+ /* Seek past any leftover bytes. For w64 the leftover will be defined based on the chunk size. */
+ if (!onSeek(pUserData, (int)(header.sizeInBytes - bytesReadSoFar), drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += (header.sizeInBytes - bytesReadSoFar);
+ }
+
+ if (header.paddingSize > 0) {
+ if (!onSeek(pUserData, header.paddingSize, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += header.paddingSize;
+ }
+
+ return DRWAV_TRUE;
+}
+
+
+static size_t drwav__on_read(drwav_read_proc onRead, void* pUserData, void* pBufferOut, size_t bytesToRead, drwav_uint64* pCursor)
+{
+ size_t bytesRead;
+
+ DRWAV_ASSERT(onRead != NULL);
+ DRWAV_ASSERT(pCursor != NULL);
+
+ bytesRead = onRead(pUserData, pBufferOut, bytesToRead);
+ *pCursor += bytesRead;
+ return bytesRead;
+}
+
+#if 0
+static drwav_bool32 drwav__on_seek(drwav_seek_proc onSeek, void* pUserData, int offset, drwav_seek_origin origin, drwav_uint64* pCursor)
+{
+ DRWAV_ASSERT(onSeek != NULL);
+ DRWAV_ASSERT(pCursor != NULL);
+
+ if (!onSeek(pUserData, offset, origin)) {
+ return DRWAV_FALSE;
+ }
+
+ if (origin == drwav_seek_origin_start) {
+ *pCursor = offset;
+ } else {
+ *pCursor += offset;
+ }
+
+ return DRWAV_TRUE;
+}
+#endif
+
+
+
+static drwav_uint32 drwav_get_bytes_per_pcm_frame(drwav* pWav)
+{
+ /*
+ The bytes per frame is a bit ambiguous. It can be either be based on the bits per sample, or the block align. The way I'm doing it here
+ is that if the bits per sample is a multiple of 8, use floor(bitsPerSample*channels/8), otherwise fall back to the block align.
+ */
+ if ((pWav->bitsPerSample & 0x7) == 0) {
+ /* Bits per sample is a multiple of 8. */
+ return (pWav->bitsPerSample * pWav->fmt.channels) >> 3;
+ } else {
+ return pWav->fmt.blockAlign;
+ }
+}
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT)
+{
+ if (pFMT == NULL) {
+ return 0;
+ }
+
+ if (pFMT->formatTag != DR_WAVE_FORMAT_EXTENSIBLE) {
+ return pFMT->formatTag;
+ } else {
+ return drwav__bytes_to_u16(pFMT->subFormat); /* Only the first two bytes are required. */
+ }
+}
+
+static drwav_bool32 drwav_preinit(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pWav == NULL || onRead == NULL || onSeek == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav));
+ pWav->onRead = onRead;
+ pWav->onSeek = onSeek;
+ pWav->pUserData = pReadSeekUserData;
+ pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks);
+
+ if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) {
+ return DRWAV_FALSE; /* Invalid allocation callbacks. */
+ }
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags)
+{
+ /* This function assumes drwav_preinit() has been called beforehand. */
+
+ drwav_uint64 cursor; /* <-- Keeps track of the byte position so we can seek to specific locations. */
+ drwav_bool32 sequential;
+ drwav_uint8 riff[4];
+ drwav_fmt fmt;
+ unsigned short translatedFormatTag;
+ drwav_bool32 foundDataChunk;
+ drwav_uint64 dataChunkSize = 0; /* <-- Important! Don't explicitly set this to 0 anywhere else. Calculation of the size of the data chunk is performed in different paths depending on the container. */
+ drwav_uint64 sampleCountFromFactChunk = 0; /* Same as dataChunkSize - make sure this is the only place this is initialized to 0. */
+ drwav_uint64 chunkSize;
+
+ cursor = 0;
+ sequential = (flags & DRWAV_SEQUENTIAL) != 0;
+
+ /* The first 4 bytes should be the RIFF identifier. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, riff, sizeof(riff), &cursor) != sizeof(riff)) {
+ return DRWAV_FALSE;
+ }
+
+ /*
+ The first 4 bytes can be used to identify the container. For RIFF files it will start with "RIFF" and for
+ w64 it will start with "riff".
+ */
+ if (drwav__fourcc_equal(riff, "RIFF")) {
+ pWav->container = drwav_container_riff;
+ } else if (drwav__fourcc_equal(riff, "riff")) {
+ int i;
+ drwav_uint8 riff2[12];
+
+ pWav->container = drwav_container_w64;
+
+ /* Check the rest of the GUID for validity. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, riff2, sizeof(riff2), &cursor) != sizeof(riff2)) {
+ return DRWAV_FALSE;
+ }
+
+ for (i = 0; i < 12; ++i) {
+ if (riff2[i] != drwavGUID_W64_RIFF[i+4]) {
+ return DRWAV_FALSE;
+ }
+ }
+ } else if (drwav__fourcc_equal(riff, "RF64")) {
+ pWav->container = drwav_container_rf64;
+ } else {
+ return DRWAV_FALSE; /* Unknown or unsupported container. */
+ }
+
+
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ drwav_uint8 chunkSizeBytes[4];
+ drwav_uint8 wave[4];
+
+ /* RIFF/WAVE */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) {
+ return DRWAV_FALSE;
+ }
+
+ if (pWav->container == drwav_container_riff) {
+ if (drwav__bytes_to_u32(chunkSizeBytes) < 36) {
+ return DRWAV_FALSE; /* Chunk size should always be at least 36 bytes. */
+ }
+ } else {
+ if (drwav__bytes_to_u32(chunkSizeBytes) != 0xFFFFFFFF) {
+ return DRWAV_FALSE; /* Chunk size should always be set to -1/0xFFFFFFFF for RF64. The actual size is retrieved later. */
+ }
+ }
+
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav__fourcc_equal(wave, "WAVE")) {
+ return DRWAV_FALSE; /* Expecting "WAVE". */
+ }
+ } else {
+ drwav_uint8 chunkSizeBytes[8];
+ drwav_uint8 wave[16];
+
+ /* W64 */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) {
+ return DRWAV_FALSE;
+ }
+
+ if (drwav__bytes_to_u64(chunkSizeBytes) < 80) {
+ return DRWAV_FALSE;
+ }
+
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav__guid_equal(wave, drwavGUID_W64_WAVE)) {
+ return DRWAV_FALSE;
+ }
+ }
+
+
+ /* For RF64, the "ds64" chunk must come next, before the "fmt " chunk. */
+ if (pWav->container == drwav_container_rf64) {
+ drwav_uint8 sizeBytes[8];
+ drwav_uint64 bytesRemainingInChunk;
+ drwav_chunk_header header;
+ drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header);
+ if (result != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav__fourcc_equal(header.id.fourcc, "ds64")) {
+ return DRWAV_FALSE; /* Expecting "ds64". */
+ }
+
+ bytesRemainingInChunk = header.sizeInBytes + header.paddingSize;
+
+ /* We don't care about the size of the RIFF chunk - skip it. */
+ if (!drwav__seek_forward(pWav->onSeek, 8, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingInChunk -= 8;
+ cursor += 8;
+
+
+ /* Next 8 bytes is the size of the "data" chunk. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingInChunk -= 8;
+ dataChunkSize = drwav__bytes_to_u64(sizeBytes);
+
+
+ /* Next 8 bytes is the same count which we would usually derived from the FACT chunk if it was available. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingInChunk -= 8;
+ sampleCountFromFactChunk = drwav__bytes_to_u64(sizeBytes);
+
+
+ /* Skip over everything else. */
+ if (!drwav__seek_forward(pWav->onSeek, bytesRemainingInChunk, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ cursor += bytesRemainingInChunk;
+ }
+
+
+ /* The next bytes should be the "fmt " chunk. */
+ if (!drwav__read_fmt(pWav->onRead, pWav->onSeek, pWav->pUserData, pWav->container, &cursor, &fmt)) {
+ return DRWAV_FALSE; /* Failed to read the "fmt " chunk. */
+ }
+
+ /* Basic validation. */
+ if ((fmt.sampleRate == 0 || fmt.sampleRate > DRWAV_MAX_SAMPLE_RATE) ||
+ (fmt.channels == 0 || fmt.channels > DRWAV_MAX_CHANNELS) ||
+ (fmt.bitsPerSample == 0 || fmt.bitsPerSample > DRWAV_MAX_BITS_PER_SAMPLE) ||
+ fmt.blockAlign == 0) {
+ return DRWAV_FALSE; /* Probably an invalid WAV file. */
+ }
+
+
+ /* Translate the internal format. */
+ translatedFormatTag = fmt.formatTag;
+ if (translatedFormatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+ translatedFormatTag = drwav__bytes_to_u16(fmt.subFormat + 0);
+ }
+
+
+ /*
+ We need to enumerate over each chunk for two reasons:
+ 1) The "data" chunk may not be the next one
+ 2) We may want to report each chunk back to the client
+
+ In order to correctly report each chunk back to the client we will need to keep looping until the end of the file.
+ */
+ foundDataChunk = DRWAV_FALSE;
+
+ /* The next chunk we care about is the "data" chunk. This is not necessarily the next chunk so we'll need to loop. */
+ for (;;)
+ {
+ drwav_chunk_header header;
+ drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header);
+ if (result != DRWAV_SUCCESS) {
+ if (!foundDataChunk) {
+ return DRWAV_FALSE;
+ } else {
+ break; /* Probably at the end of the file. Get out of the loop. */
+ }
+ }
+
+ /* Tell the client about this chunk. */
+ if (!sequential && onChunk != NULL) {
+ drwav_uint64 callbackBytesRead = onChunk(pChunkUserData, pWav->onRead, pWav->onSeek, pWav->pUserData, &header, pWav->container, &fmt);
+
+ /*
+ dr_wav may need to read the contents of the chunk, so we now need to seek back to the position before
+ we called the callback.
+ */
+ if (callbackBytesRead > 0) {
+ if (!drwav__seek_from_start(pWav->onSeek, cursor, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ }
+ }
+
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+
+ chunkSize = header.sizeInBytes;
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ if (drwav__fourcc_equal(header.id.fourcc, "data")) {
+ foundDataChunk = DRWAV_TRUE;
+ if (pWav->container != drwav_container_rf64) { /* The data chunk size for RF64 will always be set to 0xFFFFFFFF here. It was set to it's true value earlier. */
+ dataChunkSize = chunkSize;
+ }
+ }
+ } else {
+ if (drwav__guid_equal(header.id.guid, drwavGUID_W64_DATA)) {
+ foundDataChunk = DRWAV_TRUE;
+ dataChunkSize = chunkSize;
+ }
+ }
+
+ /*
+ If at this point we have found the data chunk and we're running in sequential mode, we need to break out of this loop. The reason for
+ this is that we would otherwise require a backwards seek which sequential mode forbids.
+ */
+ if (foundDataChunk && sequential) {
+ break;
+ }
+
+ /* Optional. Get the total sample count from the FACT chunk. This is useful for compressed formats. */
+ if (pWav->container == drwav_container_riff) {
+ if (drwav__fourcc_equal(header.id.fourcc, "fact")) {
+ drwav_uint32 sampleCount;
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCount, 4, &cursor) != 4) {
+ return DRWAV_FALSE;
+ }
+ chunkSize -= 4;
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+
+ /*
+ The sample count in the "fact" chunk is either unreliable, or I'm not understanding it properly. For now I am only enabling this
+ for Microsoft ADPCM formats.
+ */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ sampleCountFromFactChunk = sampleCount;
+ } else {
+ sampleCountFromFactChunk = 0;
+ }
+ }
+ } else if (pWav->container == drwav_container_w64) {
+ if (drwav__guid_equal(header.id.guid, drwavGUID_W64_FACT)) {
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCountFromFactChunk, 8, &cursor) != 8) {
+ return DRWAV_FALSE;
+ }
+ chunkSize -= 8;
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+ }
+ } else if (pWav->container == drwav_container_rf64) {
+ /* We retrieved the sample count from the ds64 chunk earlier so no need to do that here. */
+ }
+
+ /* "smpl" chunk. */
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ if (drwav__fourcc_equal(header.id.fourcc, "smpl")) {
+ drwav_uint8 smplHeaderData[36]; /* 36 = size of the smpl header section, not including the loop data. */
+ if (chunkSize >= sizeof(smplHeaderData)) {
+ drwav_uint64 bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplHeaderData, sizeof(smplHeaderData), &cursor);
+ chunkSize -= bytesJustRead;
+
+ if (bytesJustRead == sizeof(smplHeaderData)) {
+ drwav_uint32 iLoop;
+
+ pWav->smpl.manufacturer = drwav__bytes_to_u32(smplHeaderData+0);
+ pWav->smpl.product = drwav__bytes_to_u32(smplHeaderData+4);
+ pWav->smpl.samplePeriod = drwav__bytes_to_u32(smplHeaderData+8);
+ pWav->smpl.midiUnityNotes = drwav__bytes_to_u32(smplHeaderData+12);
+ pWav->smpl.midiPitchFraction = drwav__bytes_to_u32(smplHeaderData+16);
+ pWav->smpl.smpteFormat = drwav__bytes_to_u32(smplHeaderData+20);
+ pWav->smpl.smpteOffset = drwav__bytes_to_u32(smplHeaderData+24);
+ pWav->smpl.numSampleLoops = drwav__bytes_to_u32(smplHeaderData+28);
+ pWav->smpl.samplerData = drwav__bytes_to_u32(smplHeaderData+32);
+
+ for (iLoop = 0; iLoop < pWav->smpl.numSampleLoops && iLoop < drwav_countof(pWav->smpl.loops); ++iLoop) {
+ drwav_uint8 smplLoopData[24]; /* 24 = size of a loop section in the smpl chunk. */
+ bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplLoopData, sizeof(smplLoopData), &cursor);
+ chunkSize -= bytesJustRead;
+
+ if (bytesJustRead == sizeof(smplLoopData)) {
+ pWav->smpl.loops[iLoop].cuePointId = drwav__bytes_to_u32(smplLoopData+0);
+ pWav->smpl.loops[iLoop].type = drwav__bytes_to_u32(smplLoopData+4);
+ pWav->smpl.loops[iLoop].start = drwav__bytes_to_u32(smplLoopData+8);
+ pWav->smpl.loops[iLoop].end = drwav__bytes_to_u32(smplLoopData+12);
+ pWav->smpl.loops[iLoop].fraction = drwav__bytes_to_u32(smplLoopData+16);
+ pWav->smpl.loops[iLoop].playCount = drwav__bytes_to_u32(smplLoopData+20);
+ } else {
+ break; /* Break from the smpl loop for loop. */
+ }
+ }
+ }
+ } else {
+ /* Looks like invalid data. Ignore the chunk. */
+ }
+ }
+ } else {
+ if (drwav__guid_equal(header.id.guid, drwavGUID_W64_SMPL)) {
+ /*
+ This path will be hit when a W64 WAV file contains a smpl chunk. I don't have a sample file to test this path, so a contribution
+ is welcome to add support for this.
+ */
+ }
+ }
+
+ /* Make sure we seek past the padding. */
+ chunkSize += header.paddingSize;
+ if (!drwav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData)) {
+ break;
+ }
+ cursor += chunkSize;
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+ }
+
+ /* If we haven't found a data chunk, return an error. */
+ if (!foundDataChunk) {
+ return DRWAV_FALSE;
+ }
+
+ /* We may have moved passed the data chunk. If so we need to move back. If running in sequential mode we can assume we are already sitting on the data chunk. */
+ if (!sequential) {
+ if (!drwav__seek_from_start(pWav->onSeek, pWav->dataChunkDataPos, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ cursor = pWav->dataChunkDataPos;
+ }
+
+
+ /* At this point we should be sitting on the first byte of the raw audio data. */
+
+ pWav->fmt = fmt;
+ pWav->sampleRate = fmt.sampleRate;
+ pWav->channels = fmt.channels;
+ pWav->bitsPerSample = fmt.bitsPerSample;
+ pWav->bytesRemaining = dataChunkSize;
+ pWav->translatedFormatTag = translatedFormatTag;
+ pWav->dataChunkDataSize = dataChunkSize;
+
+ if (sampleCountFromFactChunk != 0) {
+ pWav->totalPCMFrameCount = sampleCountFromFactChunk;
+ } else {
+ pWav->totalPCMFrameCount = dataChunkSize / drwav_get_bytes_per_pcm_frame(pWav);
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ drwav_uint64 totalBlockHeaderSizeInBytes;
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+
+ /* Make sure any trailing partial block is accounted for. */
+ if ((blockCount * fmt.blockAlign) < dataChunkSize) {
+ blockCount += 1;
+ }
+
+ /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */
+ totalBlockHeaderSizeInBytes = blockCount * (6*fmt.channels);
+ pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels;
+ }
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ drwav_uint64 totalBlockHeaderSizeInBytes;
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+
+ /* Make sure any trailing partial block is accounted for. */
+ if ((blockCount * fmt.blockAlign) < dataChunkSize) {
+ blockCount += 1;
+ }
+
+ /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */
+ totalBlockHeaderSizeInBytes = blockCount * (4*fmt.channels);
+ pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels;
+
+ /* The header includes a decoded sample for each channel which acts as the initial predictor sample. */
+ pWav->totalPCMFrameCount += blockCount;
+ }
+ }
+
+ /* Some formats only support a certain number of channels. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM || pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ if (pWav->channels > 2) {
+ return DRWAV_FALSE;
+ }
+ }
+
+#ifdef DR_WAV_LIBSNDFILE_COMPAT
+ /*
+ I use libsndfile as a benchmark for testing, however in the version I'm using (from the Windows installer on the libsndfile website),
+ it appears the total sample count libsndfile uses for MS-ADPCM is incorrect. It would seem they are computing the total sample count
+ from the number of blocks, however this results in the inclusion of extra silent samples at the end of the last block. The correct
+ way to know the total sample count is to inspect the "fact" chunk, which should always be present for compressed formats, and should
+ always include the sample count. This little block of code below is only used to emulate the libsndfile logic so I can properly run my
+ correctness tests against libsndfile, and is disabled by default.
+ */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+ pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (6*pWav->channels))) * 2)) / fmt.channels; /* x2 because two samples per byte. */
+ }
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+ pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (4*pWav->channels))) * 2) + (blockCount * pWav->channels)) / fmt.channels;
+ }
+#endif
+
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (!drwav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init__internal(pWav, onChunk, pChunkUserData, flags);
+}
+
+
+static drwav_uint32 drwav__riff_chunk_size_riff(drwav_uint64 dataChunkSize)
+{
+ drwav_uint64 chunkSize = 4 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 24 = "fmt " chunk. */
+ if (chunkSize > 0xFFFFFFFFUL) {
+ chunkSize = 0xFFFFFFFFUL;
+ }
+
+ return (drwav_uint32)chunkSize; /* Safe cast due to the clamp above. */
+}
+
+static drwav_uint32 drwav__data_chunk_size_riff(drwav_uint64 dataChunkSize)
+{
+ if (dataChunkSize <= 0xFFFFFFFFUL) {
+ return (drwav_uint32)dataChunkSize;
+ } else {
+ return 0xFFFFFFFFUL;
+ }
+}
+
+static drwav_uint64 drwav__riff_chunk_size_w64(drwav_uint64 dataChunkSize)
+{
+ drwav_uint64 dataSubchunkPaddingSize = drwav__chunk_padding_size_w64(dataChunkSize);
+
+ return 80 + 24 + dataChunkSize + dataSubchunkPaddingSize; /* +24 because W64 includes the size of the GUID and size fields. */
+}
+
+static drwav_uint64 drwav__data_chunk_size_w64(drwav_uint64 dataChunkSize)
+{
+ return 24 + dataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+}
+
+static drwav_uint64 drwav__riff_chunk_size_rf64(drwav_uint64 dataChunkSize)
+{
+ drwav_uint64 chunkSize = 4 + 36 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 36 = "ds64" chunk. 24 = "fmt " chunk. */
+ if (chunkSize > 0xFFFFFFFFUL) {
+ chunkSize = 0xFFFFFFFFUL;
+ }
+
+ return chunkSize;
+}
+
+static drwav_uint64 drwav__data_chunk_size_rf64(drwav_uint64 dataChunkSize)
+{
+ return dataChunkSize;
+}
+
+
+static size_t drwav__write(drwav* pWav, const void* pData, size_t dataSize)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ /* Generic write. Assumes no byte reordering required. */
+ return pWav->onWrite(pWav->pUserData, pData, dataSize);
+}
+
+static size_t drwav__write_u16ne_to_le(drwav* pWav, drwav_uint16 value)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ if (!drwav__is_little_endian()) {
+ value = drwav__bswap16(value);
+ }
+
+ return drwav__write(pWav, &value, 2);
+}
+
+static size_t drwav__write_u32ne_to_le(drwav* pWav, drwav_uint32 value)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ if (!drwav__is_little_endian()) {
+ value = drwav__bswap32(value);
+ }
+
+ return drwav__write(pWav, &value, 4);
+}
+
+static size_t drwav__write_u64ne_to_le(drwav* pWav, drwav_uint64 value)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ if (!drwav__is_little_endian()) {
+ value = drwav__bswap64(value);
+ }
+
+ return drwav__write(pWav, &value, 8);
+}
+
+
+static drwav_bool32 drwav_preinit_write(drwav* pWav, const drwav_data_format* pFormat, drwav_bool32 isSequential, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pWav == NULL || onWrite == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ if (!isSequential && onSeek == NULL) {
+ return DRWAV_FALSE; /* <-- onSeek is required when in non-sequential mode. */
+ }
+
+ /* Not currently supporting compressed formats. Will need to add support for the "fact" chunk before we enable this. */
+ if (pFormat->format == DR_WAVE_FORMAT_EXTENSIBLE) {
+ return DRWAV_FALSE;
+ }
+ if (pFormat->format == DR_WAVE_FORMAT_ADPCM || pFormat->format == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return DRWAV_FALSE;
+ }
+
+ DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav));
+ pWav->onWrite = onWrite;
+ pWav->onSeek = onSeek;
+ pWav->pUserData = pUserData;
+ pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks);
+
+ if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) {
+ return DRWAV_FALSE; /* Invalid allocation callbacks. */
+ }
+
+ pWav->fmt.formatTag = (drwav_uint16)pFormat->format;
+ pWav->fmt.channels = (drwav_uint16)pFormat->channels;
+ pWav->fmt.sampleRate = pFormat->sampleRate;
+ pWav->fmt.avgBytesPerSec = (drwav_uint32)((pFormat->bitsPerSample * pFormat->sampleRate * pFormat->channels) / 8);
+ pWav->fmt.blockAlign = (drwav_uint16)((pFormat->channels * pFormat->bitsPerSample) / 8);
+ pWav->fmt.bitsPerSample = (drwav_uint16)pFormat->bitsPerSample;
+ pWav->fmt.extendedSize = 0;
+ pWav->isSequentialWrite = isSequential;
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount)
+{
+ /* The function assumes drwav_preinit_write() was called beforehand. */
+
+ size_t runningPos = 0;
+ drwav_uint64 initialDataChunkSize = 0;
+ drwav_uint64 chunkSizeFMT;
+
+ /*
+ The initial values for the "RIFF" and "data" chunks depends on whether or not we are initializing in sequential mode or not. In
+ sequential mode we set this to its final values straight away since they can be calculated from the total sample count. In non-
+ sequential mode we initialize it all to zero and fill it out in drwav_uninit() using a backwards seek.
+ */
+ if (pWav->isSequentialWrite) {
+ initialDataChunkSize = (totalSampleCount * pWav->fmt.bitsPerSample) / 8;
+
+ /*
+ The RIFF container has a limit on the number of samples. drwav is not allowing this. There's no practical limits for Wave64
+ so for the sake of simplicity I'm not doing any validation for that.
+ */
+ if (pFormat->container == drwav_container_riff) {
+ if (initialDataChunkSize > (0xFFFFFFFFUL - 36)) {
+ return DRWAV_FALSE; /* Not enough room to store every sample. */
+ }
+ }
+ }
+
+ pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize;
+
+
+ /* "RIFF" chunk. */
+ if (pFormat->container == drwav_container_riff) {
+ drwav_uint32 chunkSizeRIFF = 28 + (drwav_uint32)initialDataChunkSize; /* +28 = "WAVE" + [sizeof "fmt " chunk] */
+ runningPos += drwav__write(pWav, "RIFF", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeRIFF);
+ runningPos += drwav__write(pWav, "WAVE", 4);
+ } else if (pFormat->container == drwav_container_w64) {
+ drwav_uint64 chunkSizeRIFF = 80 + 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+ runningPos += drwav__write(pWav, drwavGUID_W64_RIFF, 16);
+ runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeRIFF);
+ runningPos += drwav__write(pWav, drwavGUID_W64_WAVE, 16);
+ } else if (pFormat->container == drwav_container_rf64) {
+ runningPos += drwav__write(pWav, "RF64", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always 0xFFFFFFFF for RF64. Set to a proper value in the "ds64" chunk. */
+ runningPos += drwav__write(pWav, "WAVE", 4);
+ }
+
+
+ /* "ds64" chunk (RF64 only). */
+ if (pFormat->container == drwav_container_rf64) {
+ drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */
+ drwav_uint64 initialRiffChunkSize = 8 + initialds64ChunkSize + initialDataChunkSize; /* +8 for the ds64 header. */
+
+ runningPos += drwav__write(pWav, "ds64", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, initialds64ChunkSize); /* Size of ds64. */
+ runningPos += drwav__write_u64ne_to_le(pWav, initialRiffChunkSize); /* Size of RIFF. Set to true value at the end. */
+ runningPos += drwav__write_u64ne_to_le(pWav, initialDataChunkSize); /* Size of DATA. Set to true value at the end. */
+ runningPos += drwav__write_u64ne_to_le(pWav, totalSampleCount); /* Sample count. */
+ runningPos += drwav__write_u32ne_to_le(pWav, 0); /* Table length. Always set to zero in our case since we're not doing any other chunks than "DATA". */
+ }
+
+
+ /* "fmt " chunk. */
+ if (pFormat->container == drwav_container_riff || pFormat->container == drwav_container_rf64) {
+ chunkSizeFMT = 16;
+ runningPos += drwav__write(pWav, "fmt ", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, (drwav_uint32)chunkSizeFMT);
+ } else if (pFormat->container == drwav_container_w64) {
+ chunkSizeFMT = 40;
+ runningPos += drwav__write(pWav, drwavGUID_W64_FMT, 16);
+ runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeFMT);
+ }
+
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.formatTag);
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.channels);
+ runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.sampleRate);
+ runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.avgBytesPerSec);
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.blockAlign);
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.bitsPerSample);
+
+ pWav->dataChunkDataPos = runningPos;
+
+ /* "data" chunk. */
+ if (pFormat->container == drwav_container_riff) {
+ drwav_uint32 chunkSizeDATA = (drwav_uint32)initialDataChunkSize;
+ runningPos += drwav__write(pWav, "data", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeDATA);
+ } else if (pFormat->container == drwav_container_w64) {
+ drwav_uint64 chunkSizeDATA = 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+ runningPos += drwav__write(pWav, drwavGUID_W64_DATA, 16);
+ runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeDATA);
+ } else if (pFormat->container == drwav_container_rf64) {
+ runningPos += drwav__write(pWav, "data", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always set to 0xFFFFFFFF for RF64. The true size of the data chunk is specified in the ds64 chunk. */
+ }
+
+ /*
+ The runningPos variable is incremented in the section above but is left unused which is causing some static analysis tools to detect it
+ as a dead store. I'm leaving this as-is for safety just in case I want to expand this function later to include other tags and want to
+ keep track of the running position for whatever reason. The line below should silence the static analysis tools.
+ */
+ (void)runningPos;
+
+ /* Set some properties for the client's convenience. */
+ pWav->container = pFormat->container;
+ pWav->channels = (drwav_uint16)pFormat->channels;
+ pWav->sampleRate = pFormat->sampleRate;
+ pWav->bitsPerSample = (drwav_uint16)pFormat->bitsPerSample;
+ pWav->translatedFormatTag = (drwav_uint16)pFormat->format;
+
+ return DRWAV_TRUE;
+}
+
+
+DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (!drwav_preinit_write(pWav, pFormat, DRWAV_FALSE, onWrite, onSeek, pUserData, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_write__internal(pWav, pFormat, 0); /* DRWAV_FALSE = Not Sequential */
+}
+
+DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (!drwav_preinit_write(pWav, pFormat, DRWAV_TRUE, onWrite, NULL, pUserData, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_write__internal(pWav, pFormat, totalSampleCount); /* DRWAV_TRUE = Sequential */
+}
+
+DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_write_sequential(pWav, pFormat, totalPCMFrameCount*pFormat->channels, onWrite, pUserData, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount)
+{
+ /* Casting totalSampleCount to drwav_int64 for VC6 compatibility. No issues in practice because nobody is going to exhaust the whole 63 bits. */
+ drwav_uint64 targetDataSizeBytes = (drwav_uint64)((drwav_int64)totalSampleCount * pFormat->channels * pFormat->bitsPerSample/8.0);
+ drwav_uint64 riffChunkSizeBytes;
+ drwav_uint64 fileSizeBytes = 0;
+
+ if (pFormat->container == drwav_container_riff) {
+ riffChunkSizeBytes = drwav__riff_chunk_size_riff(targetDataSizeBytes);
+ fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+ } else if (pFormat->container == drwav_container_w64) {
+ riffChunkSizeBytes = drwav__riff_chunk_size_w64(targetDataSizeBytes);
+ fileSizeBytes = riffChunkSizeBytes;
+ } else if (pFormat->container == drwav_container_rf64) {
+ riffChunkSizeBytes = drwav__riff_chunk_size_rf64(targetDataSizeBytes);
+ fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+ }
+
+ return fileSizeBytes;
+}
+
+
+#ifndef DR_WAV_NO_STDIO
+
+/* drwav_result_from_errno() is only used for fopen() and wfopen() so putting it inside DR_WAV_NO_STDIO for now. If something else needs this later we can move it out. */
+#include <errno.h>
+static drwav_result drwav_result_from_errno(int e)
+{
+ switch (e)
+ {
+ case 0: return DRWAV_SUCCESS;
+ #ifdef EPERM
+ case EPERM: return DRWAV_INVALID_OPERATION;
+ #endif
+ #ifdef ENOENT
+ case ENOENT: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef ESRCH
+ case ESRCH: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef EINTR
+ case EINTR: return DRWAV_INTERRUPT;
+ #endif
+ #ifdef EIO
+ case EIO: return DRWAV_IO_ERROR;
+ #endif
+ #ifdef ENXIO
+ case ENXIO: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef E2BIG
+ case E2BIG: return DRWAV_INVALID_ARGS;
+ #endif
+ #ifdef ENOEXEC
+ case ENOEXEC: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef EBADF
+ case EBADF: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ECHILD
+ case ECHILD: return DRWAV_ERROR;
+ #endif
+ #ifdef EAGAIN
+ case EAGAIN: return DRWAV_UNAVAILABLE;
+ #endif
+ #ifdef ENOMEM
+ case ENOMEM: return DRWAV_OUT_OF_MEMORY;
+ #endif
+ #ifdef EACCES
+ case EACCES: return DRWAV_ACCESS_DENIED;
+ #endif
+ #ifdef EFAULT
+ case EFAULT: return DRWAV_BAD_ADDRESS;
+ #endif
+ #ifdef ENOTBLK
+ case ENOTBLK: return DRWAV_ERROR;
+ #endif
+ #ifdef EBUSY
+ case EBUSY: return DRWAV_BUSY;
+ #endif
+ #ifdef EEXIST
+ case EEXIST: return DRWAV_ALREADY_EXISTS;
+ #endif
+ #ifdef EXDEV
+ case EXDEV: return DRWAV_ERROR;
+ #endif
+ #ifdef ENODEV
+ case ENODEV: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef ENOTDIR
+ case ENOTDIR: return DRWAV_NOT_DIRECTORY;
+ #endif
+ #ifdef EISDIR
+ case EISDIR: return DRWAV_IS_DIRECTORY;
+ #endif
+ #ifdef EINVAL
+ case EINVAL: return DRWAV_INVALID_ARGS;
+ #endif
+ #ifdef ENFILE
+ case ENFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+ #endif
+ #ifdef EMFILE
+ case EMFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+ #endif
+ #ifdef ENOTTY
+ case ENOTTY: return DRWAV_INVALID_OPERATION;
+ #endif
+ #ifdef ETXTBSY
+ case ETXTBSY: return DRWAV_BUSY;
+ #endif
+ #ifdef EFBIG
+ case EFBIG: return DRWAV_TOO_BIG;
+ #endif
+ #ifdef ENOSPC
+ case ENOSPC: return DRWAV_NO_SPACE;
+ #endif
+ #ifdef ESPIPE
+ case ESPIPE: return DRWAV_BAD_SEEK;
+ #endif
+ #ifdef EROFS
+ case EROFS: return DRWAV_ACCESS_DENIED;
+ #endif
+ #ifdef EMLINK
+ case EMLINK: return DRWAV_TOO_MANY_LINKS;
+ #endif
+ #ifdef EPIPE
+ case EPIPE: return DRWAV_BAD_PIPE;
+ #endif
+ #ifdef EDOM
+ case EDOM: return DRWAV_OUT_OF_RANGE;
+ #endif
+ #ifdef ERANGE
+ case ERANGE: return DRWAV_OUT_OF_RANGE;
+ #endif
+ #ifdef EDEADLK
+ case EDEADLK: return DRWAV_DEADLOCK;
+ #endif
+ #ifdef ENAMETOOLONG
+ case ENAMETOOLONG: return DRWAV_PATH_TOO_LONG;
+ #endif
+ #ifdef ENOLCK
+ case ENOLCK: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOSYS
+ case ENOSYS: return DRWAV_NOT_IMPLEMENTED;
+ #endif
+ #ifdef ENOTEMPTY
+ case ENOTEMPTY: return DRWAV_DIRECTORY_NOT_EMPTY;
+ #endif
+ #ifdef ELOOP
+ case ELOOP: return DRWAV_TOO_MANY_LINKS;
+ #endif
+ #ifdef ENOMSG
+ case ENOMSG: return DRWAV_NO_MESSAGE;
+ #endif
+ #ifdef EIDRM
+ case EIDRM: return DRWAV_ERROR;
+ #endif
+ #ifdef ECHRNG
+ case ECHRNG: return DRWAV_ERROR;
+ #endif
+ #ifdef EL2NSYNC
+ case EL2NSYNC: return DRWAV_ERROR;
+ #endif
+ #ifdef EL3HLT
+ case EL3HLT: return DRWAV_ERROR;
+ #endif
+ #ifdef EL3RST
+ case EL3RST: return DRWAV_ERROR;
+ #endif
+ #ifdef ELNRNG
+ case ELNRNG: return DRWAV_OUT_OF_RANGE;
+ #endif
+ #ifdef EUNATCH
+ case EUNATCH: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOCSI
+ case ENOCSI: return DRWAV_ERROR;
+ #endif
+ #ifdef EL2HLT
+ case EL2HLT: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADE
+ case EBADE: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADR
+ case EBADR: return DRWAV_ERROR;
+ #endif
+ #ifdef EXFULL
+ case EXFULL: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOANO
+ case ENOANO: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADRQC
+ case EBADRQC: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADSLT
+ case EBADSLT: return DRWAV_ERROR;
+ #endif
+ #ifdef EBFONT
+ case EBFONT: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ENOSTR
+ case ENOSTR: return DRWAV_ERROR;
+ #endif
+ #ifdef ENODATA
+ case ENODATA: return DRWAV_NO_DATA_AVAILABLE;
+ #endif
+ #ifdef ETIME
+ case ETIME: return DRWAV_TIMEOUT;
+ #endif
+ #ifdef ENOSR
+ case ENOSR: return DRWAV_NO_DATA_AVAILABLE;
+ #endif
+ #ifdef ENONET
+ case ENONET: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ENOPKG
+ case ENOPKG: return DRWAV_ERROR;
+ #endif
+ #ifdef EREMOTE
+ case EREMOTE: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOLINK
+ case ENOLINK: return DRWAV_ERROR;
+ #endif
+ #ifdef EADV
+ case EADV: return DRWAV_ERROR;
+ #endif
+ #ifdef ESRMNT
+ case ESRMNT: return DRWAV_ERROR;
+ #endif
+ #ifdef ECOMM
+ case ECOMM: return DRWAV_ERROR;
+ #endif
+ #ifdef EPROTO
+ case EPROTO: return DRWAV_ERROR;
+ #endif
+ #ifdef EMULTIHOP
+ case EMULTIHOP: return DRWAV_ERROR;
+ #endif
+ #ifdef EDOTDOT
+ case EDOTDOT: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADMSG
+ case EBADMSG: return DRWAV_BAD_MESSAGE;
+ #endif
+ #ifdef EOVERFLOW
+ case EOVERFLOW: return DRWAV_TOO_BIG;
+ #endif
+ #ifdef ENOTUNIQ
+ case ENOTUNIQ: return DRWAV_NOT_UNIQUE;
+ #endif
+ #ifdef EBADFD
+ case EBADFD: return DRWAV_ERROR;
+ #endif
+ #ifdef EREMCHG
+ case EREMCHG: return DRWAV_ERROR;
+ #endif
+ #ifdef ELIBACC
+ case ELIBACC: return DRWAV_ACCESS_DENIED;
+ #endif
+ #ifdef ELIBBAD
+ case ELIBBAD: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ELIBSCN
+ case ELIBSCN: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ELIBMAX
+ case ELIBMAX: return DRWAV_ERROR;
+ #endif
+ #ifdef ELIBEXEC
+ case ELIBEXEC: return DRWAV_ERROR;
+ #endif
+ #ifdef EILSEQ
+ case EILSEQ: return DRWAV_INVALID_DATA;
+ #endif
+ #ifdef ERESTART
+ case ERESTART: return DRWAV_ERROR;
+ #endif
+ #ifdef ESTRPIPE
+ case ESTRPIPE: return DRWAV_ERROR;
+ #endif
+ #ifdef EUSERS
+ case EUSERS: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOTSOCK
+ case ENOTSOCK: return DRWAV_NOT_SOCKET;
+ #endif
+ #ifdef EDESTADDRREQ
+ case EDESTADDRREQ: return DRWAV_NO_ADDRESS;
+ #endif
+ #ifdef EMSGSIZE
+ case EMSGSIZE: return DRWAV_TOO_BIG;
+ #endif
+ #ifdef EPROTOTYPE
+ case EPROTOTYPE: return DRWAV_BAD_PROTOCOL;
+ #endif
+ #ifdef ENOPROTOOPT
+ case ENOPROTOOPT: return DRWAV_PROTOCOL_UNAVAILABLE;
+ #endif
+ #ifdef EPROTONOSUPPORT
+ case EPROTONOSUPPORT: return DRWAV_PROTOCOL_NOT_SUPPORTED;
+ #endif
+ #ifdef ESOCKTNOSUPPORT
+ case ESOCKTNOSUPPORT: return DRWAV_SOCKET_NOT_SUPPORTED;
+ #endif
+ #ifdef EOPNOTSUPP
+ case EOPNOTSUPP: return DRWAV_INVALID_OPERATION;
+ #endif
+ #ifdef EPFNOSUPPORT
+ case EPFNOSUPPORT: return DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED;
+ #endif
+ #ifdef EAFNOSUPPORT
+ case EAFNOSUPPORT: return DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED;
+ #endif
+ #ifdef EADDRINUSE
+ case EADDRINUSE: return DRWAV_ALREADY_IN_USE;
+ #endif
+ #ifdef EADDRNOTAVAIL
+ case EADDRNOTAVAIL: return DRWAV_ERROR;
+ #endif
+ #ifdef ENETDOWN
+ case ENETDOWN: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ENETUNREACH
+ case ENETUNREACH: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ENETRESET
+ case ENETRESET: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ECONNABORTED
+ case ECONNABORTED: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ECONNRESET
+ case ECONNRESET: return DRWAV_CONNECTION_RESET;
+ #endif
+ #ifdef ENOBUFS
+ case ENOBUFS: return DRWAV_NO_SPACE;
+ #endif
+ #ifdef EISCONN
+ case EISCONN: return DRWAV_ALREADY_CONNECTED;
+ #endif
+ #ifdef ENOTCONN
+ case ENOTCONN: return DRWAV_NOT_CONNECTED;
+ #endif
+ #ifdef ESHUTDOWN
+ case ESHUTDOWN: return DRWAV_ERROR;
+ #endif
+ #ifdef ETOOMANYREFS
+ case ETOOMANYREFS: return DRWAV_ERROR;
+ #endif
+ #ifdef ETIMEDOUT
+ case ETIMEDOUT: return DRWAV_TIMEOUT;
+ #endif
+ #ifdef ECONNREFUSED
+ case ECONNREFUSED: return DRWAV_CONNECTION_REFUSED;
+ #endif
+ #ifdef EHOSTDOWN
+ case EHOSTDOWN: return DRWAV_NO_HOST;
+ #endif
+ #ifdef EHOSTUNREACH
+ case EHOSTUNREACH: return DRWAV_NO_HOST;
+ #endif
+ #ifdef EALREADY
+ case EALREADY: return DRWAV_IN_PROGRESS;
+ #endif
+ #ifdef EINPROGRESS
+ case EINPROGRESS: return DRWAV_IN_PROGRESS;
+ #endif
+ #ifdef ESTALE
+ case ESTALE: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef EUCLEAN
+ case EUCLEAN: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOTNAM
+ case ENOTNAM: return DRWAV_ERROR;
+ #endif
+ #ifdef ENAVAIL
+ case ENAVAIL: return DRWAV_ERROR;
+ #endif
+ #ifdef EISNAM
+ case EISNAM: return DRWAV_ERROR;
+ #endif
+ #ifdef EREMOTEIO
+ case EREMOTEIO: return DRWAV_IO_ERROR;
+ #endif
+ #ifdef EDQUOT
+ case EDQUOT: return DRWAV_NO_SPACE;
+ #endif
+ #ifdef ENOMEDIUM
+ case ENOMEDIUM: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef EMEDIUMTYPE
+ case EMEDIUMTYPE: return DRWAV_ERROR;
+ #endif
+ #ifdef ECANCELED
+ case ECANCELED: return DRWAV_CANCELLED;
+ #endif
+ #ifdef ENOKEY
+ case ENOKEY: return DRWAV_ERROR;
+ #endif
+ #ifdef EKEYEXPIRED
+ case EKEYEXPIRED: return DRWAV_ERROR;
+ #endif
+ #ifdef EKEYREVOKED
+ case EKEYREVOKED: return DRWAV_ERROR;
+ #endif
+ #ifdef EKEYREJECTED
+ case EKEYREJECTED: return DRWAV_ERROR;
+ #endif
+ #ifdef EOWNERDEAD
+ case EOWNERDEAD: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOTRECOVERABLE
+ case ENOTRECOVERABLE: return DRWAV_ERROR;
+ #endif
+ #ifdef ERFKILL
+ case ERFKILL: return DRWAV_ERROR;
+ #endif
+ #ifdef EHWPOISON
+ case EHWPOISON: return DRWAV_ERROR;
+ #endif
+ default: return DRWAV_ERROR;
+ }
+}
+
+static drwav_result drwav_fopen(FILE** ppFile, const char* pFilePath, const char* pOpenMode)
+{
+#if _MSC_VER && _MSC_VER >= 1400
+ errno_t err;
+#endif
+
+ if (ppFile != NULL) {
+ *ppFile = NULL; /* Safety. */
+ }
+
+ if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) {
+ return DRWAV_INVALID_ARGS;
+ }
+
+#if _MSC_VER && _MSC_VER >= 1400
+ err = fopen_s(ppFile, pFilePath, pOpenMode);
+ if (err != 0) {
+ return drwav_result_from_errno(err);
+ }
+#else
+#if defined(_WIN32) || defined(__APPLE__)
+ *ppFile = fopen(pFilePath, pOpenMode);
+#else
+ #if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS == 64 && defined(_LARGEFILE64_SOURCE)
+ *ppFile = fopen64(pFilePath, pOpenMode);
+ #else
+ *ppFile = fopen(pFilePath, pOpenMode);
+ #endif
+#endif
+ if (*ppFile == NULL) {
+ drwav_result result = drwav_result_from_errno(errno);
+ if (result == DRWAV_SUCCESS) {
+ result = DRWAV_ERROR; /* Just a safety check to make sure we never ever return success when pFile == NULL. */
+ }
+
+ return result;
+ }
+#endif
+
+ return DRWAV_SUCCESS;
+}
+
+/*
+_wfopen() isn't always available in all compilation environments.
+
+ * Windows only.
+ * MSVC seems to support it universally as far back as VC6 from what I can tell (haven't checked further back).
+ * MinGW-64 (both 32- and 64-bit) seems to support it.
+ * MinGW wraps it in !defined(__STRICT_ANSI__).
+ * OpenWatcom wraps it in !defined(_NO_EXT_KEYS).
+
+This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs()
+fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support.
+*/
+#if defined(_WIN32)
+ #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS))
+ #define DRWAV_HAS_WFOPEN
+ #endif
+#endif
+
+static drwav_result drwav_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_t* pOpenMode, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (ppFile != NULL) {
+ *ppFile = NULL; /* Safety. */
+ }
+
+ if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) {
+ return DRWAV_INVALID_ARGS;
+ }
+
+#if defined(DRWAV_HAS_WFOPEN)
+ {
+ /* Use _wfopen() on Windows. */
+ #if defined(_MSC_VER) && _MSC_VER >= 1400
+ errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode);
+ if (err != 0) {
+ return drwav_result_from_errno(err);
+ }
+ #else
+ *ppFile = _wfopen(pFilePath, pOpenMode);
+ if (*ppFile == NULL) {
+ return drwav_result_from_errno(errno);
+ }
+ #endif
+ (void)pAllocationCallbacks;
+ }
+#else
+ /*
+ Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can
+ think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for
+ maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility.
+ */
+ {
+ mbstate_t mbs;
+ size_t lenMB;
+ const wchar_t* pFilePathTemp = pFilePath;
+ char* pFilePathMB = NULL;
+ char pOpenModeMB[32] = {0};
+
+ /* Get the length first. */
+ DRWAV_ZERO_OBJECT(&mbs);
+ lenMB = wcsrtombs(NULL, &pFilePathTemp, 0, &mbs);
+ if (lenMB == (size_t)-1) {
+ return drwav_result_from_errno(errno);
+ }
+
+ pFilePathMB = (char*)drwav__malloc_from_callbacks(lenMB + 1, pAllocationCallbacks);
+ if (pFilePathMB == NULL) {
+ return DRWAV_OUT_OF_MEMORY;
+ }
+
+ pFilePathTemp = pFilePath;
+ DRWAV_ZERO_OBJECT(&mbs);
+ wcsrtombs(pFilePathMB, &pFilePathTemp, lenMB + 1, &mbs);
+
+ /* The open mode should always consist of ASCII characters so we should be able to do a trivial conversion. */
+ {
+ size_t i = 0;
+ for (;;) {
+ if (pOpenMode[i] == 0) {
+ pOpenModeMB[i] = '\0';
+ break;
+ }
+
+ pOpenModeMB[i] = (char)pOpenMode[i];
+ i += 1;
+ }
+ }
+
+ *ppFile = fopen(pFilePathMB, pOpenModeMB);
+
+ drwav__free_from_callbacks(pFilePathMB, pAllocationCallbacks);
+ }
+
+ if (*ppFile == NULL) {
+ return DRWAV_ERROR;
+ }
+#endif
+
+ return DRWAV_SUCCESS;
+}
+
+
+static size_t drwav__on_read_stdio(void* pUserData, void* pBufferOut, size_t bytesToRead)
+{
+ return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData);
+}
+
+static size_t drwav__on_write_stdio(void* pUserData, const void* pData, size_t bytesToWrite)
+{
+ return fwrite(pData, 1, bytesToWrite, (FILE*)pUserData);
+}
+
+static drwav_bool32 drwav__on_seek_stdio(void* pUserData, int offset, drwav_seek_origin origin)
+{
+ return fseek((FILE*)pUserData, offset, (origin == drwav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0;
+}
+
+DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_ex(pWav, filename, NULL, NULL, 0, pAllocationCallbacks);
+}
+
+
+static drwav_bool32 drwav_init_file__internal_FILE(drwav* pWav, FILE* pFile, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav_bool32 result;
+
+ result = drwav_preinit(pWav, drwav__on_read_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ result = drwav_init__internal(pWav, onChunk, pChunkUserData, flags);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_fopen(&pFile, filename, "rb") != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_ex_w(pWav, filename, NULL, NULL, 0, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_wfopen(&pFile, filename, L"rb", pAllocationCallbacks) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks);
+}
+
+
+static drwav_bool32 drwav_init_file_write__internal_FILE(drwav* pWav, FILE* pFile, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav_bool32 result;
+
+ result = drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ result = drwav_init_write__internal(pWav, pFormat, totalSampleCount);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init_file_write__internal(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_fopen(&pFile, filename, "wb") != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks);
+}
+
+static drwav_bool32 drwav_init_file_write_w__internal(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_wfopen(&pFile, filename, L"wb", pAllocationCallbacks) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_file_write_sequential(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write_w__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write_w__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_file_write_sequential_w(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks);
+}
+#endif /* DR_WAV_NO_STDIO */
+
+
+static size_t drwav__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead)
+{
+ drwav* pWav = (drwav*)pUserData;
+ size_t bytesRemaining;
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->memoryStream.dataSize >= pWav->memoryStream.currentReadPos);
+
+ bytesRemaining = pWav->memoryStream.dataSize - pWav->memoryStream.currentReadPos;
+ if (bytesToRead > bytesRemaining) {
+ bytesToRead = bytesRemaining;
+ }
+
+ if (bytesToRead > 0) {
+ DRWAV_COPY_MEMORY(pBufferOut, pWav->memoryStream.data + pWav->memoryStream.currentReadPos, bytesToRead);
+ pWav->memoryStream.currentReadPos += bytesToRead;
+ }
+
+ return bytesToRead;
+}
+
+static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_seek_origin origin)
+{
+ drwav* pWav = (drwav*)pUserData;
+ DRWAV_ASSERT(pWav != NULL);
+
+ if (origin == drwav_seek_origin_current) {
+ if (offset > 0) {
+ if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) {
+ return DRWAV_FALSE; /* Trying to seek too far forward. */
+ }
+ } else {
+ if (pWav->memoryStream.currentReadPos < (size_t)-offset) {
+ return DRWAV_FALSE; /* Trying to seek too far backwards. */
+ }
+ }
+
+ /* This will never underflow thanks to the clamps above. */
+ pWav->memoryStream.currentReadPos += offset;
+ } else {
+ if ((drwav_uint32)offset <= pWav->memoryStream.dataSize) {
+ pWav->memoryStream.currentReadPos = offset;
+ } else {
+ return DRWAV_FALSE; /* Trying to seek too far forward. */
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+static size_t drwav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite)
+{
+ drwav* pWav = (drwav*)pUserData;
+ size_t bytesRemaining;
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->memoryStreamWrite.dataCapacity >= pWav->memoryStreamWrite.currentWritePos);
+
+ bytesRemaining = pWav->memoryStreamWrite.dataCapacity - pWav->memoryStreamWrite.currentWritePos;
+ if (bytesRemaining < bytesToWrite) {
+ /* Need to reallocate. */
+ void* pNewData;
+ size_t newDataCapacity = (pWav->memoryStreamWrite.dataCapacity == 0) ? 256 : pWav->memoryStreamWrite.dataCapacity * 2;
+
+ /* If doubling wasn't enough, just make it the minimum required size to write the data. */
+ if ((newDataCapacity - pWav->memoryStreamWrite.currentWritePos) < bytesToWrite) {
+ newDataCapacity = pWav->memoryStreamWrite.currentWritePos + bytesToWrite;
+ }
+
+ pNewData = drwav__realloc_from_callbacks(*pWav->memoryStreamWrite.ppData, newDataCapacity, pWav->memoryStreamWrite.dataCapacity, &pWav->allocationCallbacks);
+ if (pNewData == NULL) {
+ return 0;
+ }
+
+ *pWav->memoryStreamWrite.ppData = pNewData;
+ pWav->memoryStreamWrite.dataCapacity = newDataCapacity;
+ }
+
+ DRWAV_COPY_MEMORY(((drwav_uint8*)(*pWav->memoryStreamWrite.ppData)) + pWav->memoryStreamWrite.currentWritePos, pDataIn, bytesToWrite);
+
+ pWav->memoryStreamWrite.currentWritePos += bytesToWrite;
+ if (pWav->memoryStreamWrite.dataSize < pWav->memoryStreamWrite.currentWritePos) {
+ pWav->memoryStreamWrite.dataSize = pWav->memoryStreamWrite.currentWritePos;
+ }
+
+ *pWav->memoryStreamWrite.pDataSize = pWav->memoryStreamWrite.dataSize;
+
+ return bytesToWrite;
+}
+
+static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drwav_seek_origin origin)
+{
+ drwav* pWav = (drwav*)pUserData;
+ DRWAV_ASSERT(pWav != NULL);
+
+ if (origin == drwav_seek_origin_current) {
+ if (offset > 0) {
+ if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) {
+ offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); /* Trying to seek too far forward. */
+ }
+ } else {
+ if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) {
+ offset = -(int)pWav->memoryStreamWrite.currentWritePos; /* Trying to seek too far backwards. */
+ }
+ }
+
+ /* This will never underflow thanks to the clamps above. */
+ pWav->memoryStreamWrite.currentWritePos += offset;
+ } else {
+ if ((drwav_uint32)offset <= pWav->memoryStreamWrite.dataSize) {
+ pWav->memoryStreamWrite.currentWritePos = offset;
+ } else {
+ pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Trying to seek too far forward. */
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_memory_ex(pWav, data, dataSize, NULL, NULL, 0, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (data == NULL || dataSize == 0) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav_preinit(pWav, drwav__on_read_memory, drwav__on_seek_memory, pWav, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ pWav->memoryStream.data = (const drwav_uint8*)data;
+ pWav->memoryStream.dataSize = dataSize;
+ pWav->memoryStream.currentReadPos = 0;
+
+ return drwav_init__internal(pWav, onChunk, pChunkUserData, flags);
+}
+
+
+static drwav_bool32 drwav_init_memory_write__internal(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (ppData == NULL || pDataSize == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ *ppData = NULL; /* Important because we're using realloc()! */
+ *pDataSize = 0;
+
+ if (!drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_memory, drwav__on_seek_memory_write, pWav, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ pWav->memoryStreamWrite.ppData = ppData;
+ pWav->memoryStreamWrite.pDataSize = pDataSize;
+ pWav->memoryStreamWrite.dataSize = 0;
+ pWav->memoryStreamWrite.dataCapacity = 0;
+ pWav->memoryStreamWrite.currentWritePos = 0;
+
+ return drwav_init_write__internal(pWav, pFormat, totalSampleCount);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_memory_write_sequential(pWav, ppData, pDataSize, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks);
+}
+
+
+
+DRWAV_API drwav_result drwav_uninit(drwav* pWav)
+{
+ drwav_result result = DRWAV_SUCCESS;
+
+ if (pWav == NULL) {
+ return DRWAV_INVALID_ARGS;
+ }
+
+ /*
+ If the drwav object was opened in write mode we'll need to finalize a few things:
+ - Make sure the "data" chunk is aligned to 16-bits for RIFF containers, or 64 bits for W64 containers.
+ - Set the size of the "data" chunk.
+ */
+ if (pWav->onWrite != NULL) {
+ drwav_uint32 paddingSize = 0;
+
+ /* Padding. Do not adjust pWav->dataChunkDataSize - this should not include the padding. */
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ paddingSize = drwav__chunk_padding_size_riff(pWav->dataChunkDataSize);
+ } else {
+ paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize);
+ }
+
+ if (paddingSize > 0) {
+ drwav_uint64 paddingData = 0;
+ drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */
+ }
+
+ /*
+ Chunk sizes. When using sequential mode, these will have been filled in at initialization time. We only need
+ to do this when using non-sequential mode.
+ */
+ if (pWav->onSeek && !pWav->isSequentialWrite) {
+ if (pWav->container == drwav_container_riff) {
+ /* The "RIFF" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, 4, drwav_seek_origin_start)) {
+ drwav_uint32 riffChunkSize = drwav__riff_chunk_size_riff(pWav->dataChunkDataSize);
+ drwav__write_u32ne_to_le(pWav, riffChunkSize);
+ }
+
+ /* the "data" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 4, drwav_seek_origin_start)) {
+ drwav_uint32 dataChunkSize = drwav__data_chunk_size_riff(pWav->dataChunkDataSize);
+ drwav__write_u32ne_to_le(pWav, dataChunkSize);
+ }
+ } else if (pWav->container == drwav_container_w64) {
+ /* The "RIFF" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, 16, drwav_seek_origin_start)) {
+ drwav_uint64 riffChunkSize = drwav__riff_chunk_size_w64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, riffChunkSize);
+ }
+
+ /* The "data" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 16, drwav_seek_origin_start)) {
+ drwav_uint64 dataChunkSize = drwav__data_chunk_size_w64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, dataChunkSize);
+ }
+ } else if (pWav->container == drwav_container_rf64) {
+ /* We only need to update the ds64 chunk. The "RIFF" and "data" chunks always have their sizes set to 0xFFFFFFFF for RF64. */
+ int ds64BodyPos = 12 + 8;
+
+ /* The "RIFF" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, drwav_seek_origin_start)) {
+ drwav_uint64 riffChunkSize = drwav__riff_chunk_size_rf64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, riffChunkSize);
+ }
+
+ /* The "data" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, drwav_seek_origin_start)) {
+ drwav_uint64 dataChunkSize = drwav__data_chunk_size_rf64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, dataChunkSize);
+ }
+ }
+ }
+
+ /* Validation for sequential mode. */
+ if (pWav->isSequentialWrite) {
+ if (pWav->dataChunkDataSize != pWav->dataChunkDataSizeTargetWrite) {
+ result = DRWAV_INVALID_FILE;
+ }
+ }
+ }
+
+#ifndef DR_WAV_NO_STDIO
+ /*
+ If we opened the file with drwav_open_file() we will want to close the file handle. We can know whether or not drwav_open_file()
+ was used by looking at the onRead and onSeek callbacks.
+ */
+ if (pWav->onRead == drwav__on_read_stdio || pWav->onWrite == drwav__on_write_stdio) {
+ fclose((FILE*)pWav->pUserData);
+ }
+#endif
+
+ return result;
+}
+
+
+
+DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut)
+{
+ size_t bytesRead;
+
+ if (pWav == NULL || bytesToRead == 0) {
+ return 0;
+ }
+
+ if (bytesToRead > pWav->bytesRemaining) {
+ bytesToRead = (size_t)pWav->bytesRemaining;
+ }
+
+ if (pBufferOut != NULL) {
+ bytesRead = pWav->onRead(pWav->pUserData, pBufferOut, bytesToRead);
+ } else {
+ /* We need to seek. If we fail, we need to read-and-discard to make sure we get a good byte count. */
+ bytesRead = 0;
+ while (bytesRead < bytesToRead) {
+ size_t bytesToSeek = (bytesToRead - bytesRead);
+ if (bytesToSeek > 0x7FFFFFFF) {
+ bytesToSeek = 0x7FFFFFFF;
+ }
+
+ if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, drwav_seek_origin_current) == DRWAV_FALSE) {
+ break;
+ }
+
+ bytesRead += bytesToSeek;
+ }
+
+ /* When we get here we may need to read-and-discard some data. */
+ while (bytesRead < bytesToRead) {
+ drwav_uint8 buffer[4096];
+ size_t bytesSeeked;
+ size_t bytesToSeek = (bytesToRead - bytesRead);
+ if (bytesToSeek > sizeof(buffer)) {
+ bytesToSeek = sizeof(buffer);
+ }
+
+ bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek);
+ bytesRead += bytesSeeked;
+
+ if (bytesSeeked < bytesToSeek) {
+ break; /* Reached the end. */
+ }
+ }
+ }
+
+ pWav->bytesRemaining -= bytesRead;
+ return bytesRead;
+}
+
+
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+ drwav_uint32 bytesPerFrame;
+ drwav_uint64 bytesToRead; /* Intentionally uint64 instead of size_t so we can do a check that we're not reading too much on 32-bit builds. */
+
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ /* Cannot use this function for compressed formats. */
+ if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+ return 0;
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ bytesToRead = framesToRead * bytesPerFrame;
+ if (bytesToRead > DRWAV_SIZE_MAX) {
+ bytesToRead = (DRWAV_SIZE_MAX / bytesPerFrame) * bytesPerFrame; /* Round the number of bytes to read to a clean frame boundary. */
+ }
+
+ /*
+ Doing an explicit check here just to make it clear that we don't want to be attempt to read anything if there's no bytes to read. There
+ *could* be a time where it evaluates to 0 due to overflowing.
+ */
+ if (bytesToRead == 0) {
+ return 0;
+ }
+
+ return drwav_read_raw(pWav, (size_t)bytesToRead, pBufferOut) / bytesPerFrame;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut);
+
+ if (pBufferOut != NULL) {
+ drwav__bswap_samples(pBufferOut, framesRead*pWav->channels, drwav_get_bytes_per_pcm_frame(pWav)/pWav->channels, pWav->translatedFormatTag);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+ if (drwav__is_little_endian()) {
+ return drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut);
+ } else {
+ return drwav_read_pcm_frames_be(pWav, framesToRead, pBufferOut);
+ }
+}
+
+
+
+DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav)
+{
+ if (pWav->onWrite != NULL) {
+ return DRWAV_FALSE; /* No seeking in write mode. */
+ }
+
+ if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, drwav_seek_origin_start)) {
+ return DRWAV_FALSE;
+ }
+
+ if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+ pWav->compressed.iCurrentPCMFrame = 0;
+
+ /* Cached data needs to be cleared for compressed formats. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ DRWAV_ZERO_OBJECT(&pWav->msadpcm);
+ } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ DRWAV_ZERO_OBJECT(&pWav->ima);
+ } else {
+ DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */
+ }
+ }
+
+ pWav->bytesRemaining = pWav->dataChunkDataSize;
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex)
+{
+ /* Seeking should be compatible with wave files > 2GB. */
+
+ if (pWav == NULL || pWav->onSeek == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ /* No seeking in write mode. */
+ if (pWav->onWrite != NULL) {
+ return DRWAV_FALSE;
+ }
+
+ /* If there are no samples, just return DRWAV_TRUE without doing anything. */
+ if (pWav->totalPCMFrameCount == 0) {
+ return DRWAV_TRUE;
+ }
+
+ /* Make sure the sample is clamped. */
+ if (targetFrameIndex >= pWav->totalPCMFrameCount) {
+ targetFrameIndex = pWav->totalPCMFrameCount - 1;
+ }
+
+ /*
+ For compressed formats we just use a slow generic seek. If we are seeking forward we just seek forward. If we are going backwards we need
+ to seek back to the start.
+ */
+ if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+ /* TODO: This can be optimized. */
+
+ /*
+ If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting. If we're seeking backwards,
+ we first need to seek back to the start and then just do the same thing as a forward seek.
+ */
+ if (targetFrameIndex < pWav->compressed.iCurrentPCMFrame) {
+ if (!drwav_seek_to_first_pcm_frame(pWav)) {
+ return DRWAV_FALSE;
+ }
+ }
+
+ if (targetFrameIndex > pWav->compressed.iCurrentPCMFrame) {
+ drwav_uint64 offsetInFrames = targetFrameIndex - pWav->compressed.iCurrentPCMFrame;
+
+ drwav_int16 devnull[2048];
+ while (offsetInFrames > 0) {
+ drwav_uint64 framesRead = 0;
+ drwav_uint64 framesToRead = offsetInFrames;
+ if (framesToRead > drwav_countof(devnull)/pWav->channels) {
+ framesToRead = drwav_countof(devnull)/pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, devnull);
+ } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, devnull);
+ } else {
+ DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */
+ }
+
+ if (framesRead != framesToRead) {
+ return DRWAV_FALSE;
+ }
+
+ offsetInFrames -= framesRead;
+ }
+ }
+ } else {
+ drwav_uint64 totalSizeInBytes;
+ drwav_uint64 currentBytePos;
+ drwav_uint64 targetBytePos;
+ drwav_uint64 offset;
+
+ totalSizeInBytes = pWav->totalPCMFrameCount * drwav_get_bytes_per_pcm_frame(pWav);
+ DRWAV_ASSERT(totalSizeInBytes >= pWav->bytesRemaining);
+
+ currentBytePos = totalSizeInBytes - pWav->bytesRemaining;
+ targetBytePos = targetFrameIndex * drwav_get_bytes_per_pcm_frame(pWav);
+
+ if (currentBytePos < targetBytePos) {
+ /* Offset forwards. */
+ offset = (targetBytePos - currentBytePos);
+ } else {
+ /* Offset backwards. */
+ if (!drwav_seek_to_first_pcm_frame(pWav)) {
+ return DRWAV_FALSE;
+ }
+ offset = targetBytePos;
+ }
+
+ while (offset > 0) {
+ int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset);
+ if (!pWav->onSeek(pWav->pUserData, offset32, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+
+ pWav->bytesRemaining -= offset32;
+ offset -= offset32;
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+
+DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData)
+{
+ size_t bytesWritten;
+
+ if (pWav == NULL || bytesToWrite == 0 || pData == NULL) {
+ return 0;
+ }
+
+ bytesWritten = pWav->onWrite(pWav->pUserData, pData, bytesToWrite);
+ pWav->dataChunkDataSize += bytesWritten;
+
+ return bytesWritten;
+}
+
+
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData)
+{
+ drwav_uint64 bytesToWrite;
+ drwav_uint64 bytesWritten;
+ const drwav_uint8* pRunningData;
+
+ if (pWav == NULL || framesToWrite == 0 || pData == NULL) {
+ return 0;
+ }
+
+ bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8);
+ if (bytesToWrite > DRWAV_SIZE_MAX) {
+ return 0;
+ }
+
+ bytesWritten = 0;
+ pRunningData = (const drwav_uint8*)pData;
+
+ while (bytesToWrite > 0) {
+ size_t bytesJustWritten;
+ drwav_uint64 bytesToWriteThisIteration;
+
+ bytesToWriteThisIteration = bytesToWrite;
+ DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */
+
+ bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, pRunningData);
+ if (bytesJustWritten == 0) {
+ break;
+ }
+
+ bytesToWrite -= bytesJustWritten;
+ bytesWritten += bytesJustWritten;
+ pRunningData += bytesJustWritten;
+ }
+
+ return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels;
+}
+
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData)
+{
+ drwav_uint64 bytesToWrite;
+ drwav_uint64 bytesWritten;
+ drwav_uint32 bytesPerSample;
+ const drwav_uint8* pRunningData;
+
+ if (pWav == NULL || framesToWrite == 0 || pData == NULL) {
+ return 0;
+ }
+
+ bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8);
+ if (bytesToWrite > DRWAV_SIZE_MAX) {
+ return 0;
+ }
+
+ bytesWritten = 0;
+ pRunningData = (const drwav_uint8*)pData;
+
+ bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels;
+
+ while (bytesToWrite > 0) {
+ drwav_uint8 temp[4096];
+ drwav_uint32 sampleCount;
+ size_t bytesJustWritten;
+ drwav_uint64 bytesToWriteThisIteration;
+
+ bytesToWriteThisIteration = bytesToWrite;
+ DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */
+
+ /*
+ WAV files are always little-endian. We need to byte swap on big-endian architectures. Since our input buffer is read-only we need
+ to use an intermediary buffer for the conversion.
+ */
+ sampleCount = sizeof(temp)/bytesPerSample;
+
+ if (bytesToWriteThisIteration > ((drwav_uint64)sampleCount)*bytesPerSample) {
+ bytesToWriteThisIteration = ((drwav_uint64)sampleCount)*bytesPerSample;
+ }
+
+ DRWAV_COPY_MEMORY(temp, pRunningData, (size_t)bytesToWriteThisIteration);
+ drwav__bswap_samples(temp, sampleCount, bytesPerSample, pWav->translatedFormatTag);
+
+ bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, temp);
+ if (bytesJustWritten == 0) {
+ break;
+ }
+
+ bytesToWrite -= bytesJustWritten;
+ bytesWritten += bytesJustWritten;
+ pRunningData += bytesJustWritten;
+ }
+
+ return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels;
+}
+
+DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData)
+{
+ if (drwav__is_little_endian()) {
+ return drwav_write_pcm_frames_le(pWav, framesToWrite, pData);
+ } else {
+ return drwav_write_pcm_frames_be(pWav, framesToWrite, pData);
+ }
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead = 0;
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(framesToRead > 0);
+
+ /* TODO: Lots of room for optimization here. */
+
+ while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ /* If there are no cached frames we need to load a new block. */
+ if (pWav->msadpcm.cachedFrameCount == 0 && pWav->msadpcm.bytesRemainingInBlock == 0) {
+ if (pWav->channels == 1) {
+ /* Mono. */
+ drwav_uint8 header[7];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ pWav->msadpcm.predictor[0] = header[0];
+ pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 1);
+ pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 3);
+ pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 5);
+ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0];
+ pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.cachedFrameCount = 2;
+ } else {
+ /* Stereo. */
+ drwav_uint8 header[14];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ pWav->msadpcm.predictor[0] = header[0];
+ pWav->msadpcm.predictor[1] = header[1];
+ pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 2);
+ pWav->msadpcm.delta[1] = drwav__bytes_to_s16(header + 4);
+ pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 6);
+ pWav->msadpcm.prevFrames[1][1] = (drwav_int32)drwav__bytes_to_s16(header + 8);
+ pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 10);
+ pWav->msadpcm.prevFrames[1][0] = (drwav_int32)drwav__bytes_to_s16(header + 12);
+
+ pWav->msadpcm.cachedFrames[0] = pWav->msadpcm.prevFrames[0][0];
+ pWav->msadpcm.cachedFrames[1] = pWav->msadpcm.prevFrames[1][0];
+ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1];
+ pWav->msadpcm.cachedFrameCount = 2;
+ }
+ }
+
+ /* Output anything that's cached. */
+ while (framesToRead > 0 && pWav->msadpcm.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ if (pBufferOut != NULL) {
+ drwav_uint32 iSample = 0;
+ for (iSample = 0; iSample < pWav->channels; iSample += 1) {
+ pBufferOut[iSample] = (drwav_int16)pWav->msadpcm.cachedFrames[(drwav_countof(pWav->msadpcm.cachedFrames) - (pWav->msadpcm.cachedFrameCount*pWav->channels)) + iSample];
+ }
+
+ pBufferOut += pWav->channels;
+ }
+
+ framesToRead -= 1;
+ totalFramesRead += 1;
+ pWav->compressed.iCurrentPCMFrame += 1;
+ pWav->msadpcm.cachedFrameCount -= 1;
+ }
+
+ if (framesToRead == 0) {
+ return totalFramesRead;
+ }
+
+
+ /*
+ If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next
+ loop iteration which will trigger the loading of a new block.
+ */
+ if (pWav->msadpcm.cachedFrameCount == 0) {
+ if (pWav->msadpcm.bytesRemainingInBlock == 0) {
+ continue;
+ } else {
+ static drwav_int32 adaptationTable[] = {
+ 230, 230, 230, 230, 307, 409, 512, 614,
+ 768, 614, 512, 409, 307, 230, 230, 230
+ };
+ static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 };
+ static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 };
+
+ drwav_uint8 nibbles;
+ drwav_int32 nibble0;
+ drwav_int32 nibble1;
+
+ if (pWav->onRead(pWav->pUserData, &nibbles, 1) != 1) {
+ return totalFramesRead;
+ }
+ pWav->msadpcm.bytesRemainingInBlock -= 1;
+
+ /* TODO: Optimize away these if statements. */
+ nibble0 = ((nibbles & 0xF0) >> 4); if ((nibbles & 0x80)) { nibble0 |= 0xFFFFFFF0UL; }
+ nibble1 = ((nibbles & 0x0F) >> 0); if ((nibbles & 0x08)) { nibble1 |= 0xFFFFFFF0UL; }
+
+ if (pWav->channels == 1) {
+ /* Mono. */
+ drwav_int32 newSample0;
+ drwav_int32 newSample1;
+
+ newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
+ newSample0 += nibble0 * pWav->msadpcm.delta[0];
+ newSample0 = drwav_clamp(newSample0, -32768, 32767);
+
+ pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8;
+ if (pWav->msadpcm.delta[0] < 16) {
+ pWav->msadpcm.delta[0] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.prevFrames[0][1] = newSample0;
+
+
+ newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
+ newSample1 += nibble1 * pWav->msadpcm.delta[0];
+ newSample1 = drwav_clamp(newSample1, -32768, 32767);
+
+ pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8;
+ if (pWav->msadpcm.delta[0] < 16) {
+ pWav->msadpcm.delta[0] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.prevFrames[0][1] = newSample1;
+
+
+ pWav->msadpcm.cachedFrames[2] = newSample0;
+ pWav->msadpcm.cachedFrames[3] = newSample1;
+ pWav->msadpcm.cachedFrameCount = 2;
+ } else {
+ /* Stereo. */
+ drwav_int32 newSample0;
+ drwav_int32 newSample1;
+
+ /* Left. */
+ newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
+ newSample0 += nibble0 * pWav->msadpcm.delta[0];
+ newSample0 = drwav_clamp(newSample0, -32768, 32767);
+
+ pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8;
+ if (pWav->msadpcm.delta[0] < 16) {
+ pWav->msadpcm.delta[0] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.prevFrames[0][1] = newSample0;
+
+
+ /* Right. */
+ newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8;
+ newSample1 += nibble1 * pWav->msadpcm.delta[1];
+ newSample1 = drwav_clamp(newSample1, -32768, 32767);
+
+ pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8;
+ if (pWav->msadpcm.delta[1] < 16) {
+ pWav->msadpcm.delta[1] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1];
+ pWav->msadpcm.prevFrames[1][1] = newSample1;
+
+ pWav->msadpcm.cachedFrames[2] = newSample0;
+ pWav->msadpcm.cachedFrames[3] = newSample1;
+ pWav->msadpcm.cachedFrameCount = 1;
+ }
+ }
+ }
+ }
+
+ return totalFramesRead;
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead = 0;
+ drwav_uint32 iChannel;
+
+ static drwav_int32 indexTable[16] = {
+ -1, -1, -1, -1, 2, 4, 6, 8,
+ -1, -1, -1, -1, 2, 4, 6, 8
+ };
+
+ static drwav_int32 stepTable[89] = {
+ 7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
+ 19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
+ 50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
+ 130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
+ 337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
+ 876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
+ 2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
+ 5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
+ 15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
+ };
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(framesToRead > 0);
+
+ /* TODO: Lots of room for optimization here. */
+
+ while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ /* If there are no cached samples we need to load a new block. */
+ if (pWav->ima.cachedFrameCount == 0 && pWav->ima.bytesRemainingInBlock == 0) {
+ if (pWav->channels == 1) {
+ /* Mono. */
+ drwav_uint8 header[4];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ if (header[2] >= drwav_countof(stepTable)) {
+ pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current);
+ pWav->ima.bytesRemainingInBlock = 0;
+ return totalFramesRead; /* Invalid data. */
+ }
+
+ pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0);
+ pWav->ima.stepIndex[0] = header[2];
+ pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[0];
+ pWav->ima.cachedFrameCount = 1;
+ } else {
+ /* Stereo. */
+ drwav_uint8 header[8];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ if (header[2] >= drwav_countof(stepTable) || header[6] >= drwav_countof(stepTable)) {
+ pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current);
+ pWav->ima.bytesRemainingInBlock = 0;
+ return totalFramesRead; /* Invalid data. */
+ }
+
+ pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0);
+ pWav->ima.stepIndex[0] = header[2];
+ pWav->ima.predictor[1] = drwav__bytes_to_s16(header + 4);
+ pWav->ima.stepIndex[1] = header[6];
+
+ pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 2] = pWav->ima.predictor[0];
+ pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[1];
+ pWav->ima.cachedFrameCount = 1;
+ }
+ }
+
+ /* Output anything that's cached. */
+ while (framesToRead > 0 && pWav->ima.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ if (pBufferOut != NULL) {
+ drwav_uint32 iSample;
+ for (iSample = 0; iSample < pWav->channels; iSample += 1) {
+ pBufferOut[iSample] = (drwav_int16)pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + iSample];
+ }
+ pBufferOut += pWav->channels;
+ }
+
+ framesToRead -= 1;
+ totalFramesRead += 1;
+ pWav->compressed.iCurrentPCMFrame += 1;
+ pWav->ima.cachedFrameCount -= 1;
+ }
+
+ if (framesToRead == 0) {
+ return totalFramesRead;
+ }
+
+ /*
+ If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next
+ loop iteration which will trigger the loading of a new block.
+ */
+ if (pWav->ima.cachedFrameCount == 0) {
+ if (pWav->ima.bytesRemainingInBlock == 0) {
+ continue;
+ } else {
+ /*
+ From what I can tell with stereo streams, it looks like every 4 bytes (8 samples) is for one channel. So it goes 4 bytes for the
+ left channel, 4 bytes for the right channel.
+ */
+ pWav->ima.cachedFrameCount = 8;
+ for (iChannel = 0; iChannel < pWav->channels; ++iChannel) {
+ drwav_uint32 iByte;
+ drwav_uint8 nibbles[4];
+ if (pWav->onRead(pWav->pUserData, &nibbles, 4) != 4) {
+ pWav->ima.cachedFrameCount = 0;
+ return totalFramesRead;
+ }
+ pWav->ima.bytesRemainingInBlock -= 4;
+
+ for (iByte = 0; iByte < 4; ++iByte) {
+ drwav_uint8 nibble0 = ((nibbles[iByte] & 0x0F) >> 0);
+ drwav_uint8 nibble1 = ((nibbles[iByte] & 0xF0) >> 4);
+
+ drwav_int32 step = stepTable[pWav->ima.stepIndex[iChannel]];
+ drwav_int32 predictor = pWav->ima.predictor[iChannel];
+
+ drwav_int32 diff = step >> 3;
+ if (nibble0 & 1) diff += step >> 2;
+ if (nibble0 & 2) diff += step >> 1;
+ if (nibble0 & 4) diff += step;
+ if (nibble0 & 8) diff = -diff;
+
+ predictor = drwav_clamp(predictor + diff, -32768, 32767);
+ pWav->ima.predictor[iChannel] = predictor;
+ pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble0], 0, (drwav_int32)drwav_countof(stepTable)-1);
+ pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+0)*pWav->channels + iChannel] = predictor;
+
+
+ step = stepTable[pWav->ima.stepIndex[iChannel]];
+ predictor = pWav->ima.predictor[iChannel];
+
+ diff = step >> 3;
+ if (nibble1 & 1) diff += step >> 2;
+ if (nibble1 & 2) diff += step >> 1;
+ if (nibble1 & 4) diff += step;
+ if (nibble1 & 8) diff = -diff;
+
+ predictor = drwav_clamp(predictor + diff, -32768, 32767);
+ pWav->ima.predictor[iChannel] = predictor;
+ pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble1], 0, (drwav_int32)drwav_countof(stepTable)-1);
+ pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+1)*pWav->channels + iChannel] = predictor;
+ }
+ }
+ }
+ }
+ }
+
+ return totalFramesRead;
+}
+
+
+#ifndef DR_WAV_NO_CONVERSION_API
+static unsigned short g_drwavAlawTable[256] = {
+ 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580,
+ 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0,
+ 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600,
+ 0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00,
+ 0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58,
+ 0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58,
+ 0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960,
+ 0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0,
+ 0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80,
+ 0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40,
+ 0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00,
+ 0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500,
+ 0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8,
+ 0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8,
+ 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0,
+ 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350
+};
+
+static unsigned short g_drwavMulawTable[256] = {
+ 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84,
+ 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84,
+ 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004,
+ 0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844,
+ 0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64,
+ 0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74,
+ 0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C,
+ 0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000,
+ 0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C,
+ 0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C,
+ 0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC,
+ 0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC,
+ 0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C,
+ 0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C,
+ 0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084,
+ 0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000
+};
+
+static DRWAV_INLINE drwav_int16 drwav__alaw_to_s16(drwav_uint8 sampleIn)
+{
+ return (short)g_drwavAlawTable[sampleIn];
+}
+
+static DRWAV_INLINE drwav_int16 drwav__mulaw_to_s16(drwav_uint8 sampleIn)
+{
+ return (short)g_drwavMulawTable[sampleIn];
+}
+
+
+
+static void drwav__pcm_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ unsigned int i;
+
+ /* Special case for 8-bit sample data because it's treated as unsigned. */
+ if (bytesPerSample == 1) {
+ drwav_u8_to_s16(pOut, pIn, totalSampleCount);
+ return;
+ }
+
+
+ /* Slightly more optimal implementation for common formats. */
+ if (bytesPerSample == 2) {
+ for (i = 0; i < totalSampleCount; ++i) {
+ *pOut++ = ((const drwav_int16*)pIn)[i];
+ }
+ return;
+ }
+ if (bytesPerSample == 3) {
+ drwav_s24_to_s16(pOut, pIn, totalSampleCount);
+ return;
+ }
+ if (bytesPerSample == 4) {
+ drwav_s32_to_s16(pOut, (const drwav_int32*)pIn, totalSampleCount);
+ return;
+ }
+
+
+ /* Anything more than 64 bits per sample is not supported. */
+ if (bytesPerSample > 8) {
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+
+
+ /* Generic, slow converter. */
+ for (i = 0; i < totalSampleCount; ++i) {
+ drwav_uint64 sample = 0;
+ unsigned int shift = (8 - bytesPerSample) * 8;
+
+ unsigned int j;
+ for (j = 0; j < bytesPerSample; j += 1) {
+ DRWAV_ASSERT(j < 8);
+ sample |= (drwav_uint64)(pIn[j]) << shift;
+ shift += 8;
+ }
+
+ pIn += j;
+ *pOut++ = (drwav_int16)((drwav_int64)sample >> 48);
+ }
+}
+
+static void drwav__ieee_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ if (bytesPerSample == 4) {
+ drwav_f32_to_s16(pOut, (const float*)pIn, totalSampleCount);
+ return;
+ } else if (bytesPerSample == 8) {
+ drwav_f64_to_s16(pOut, (const double*)pIn, totalSampleCount);
+ return;
+ } else {
+ /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint32 bytesPerFrame;
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ /* Fast path. */
+ if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__pcm_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__ieee_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_alaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_mulaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ if (framesToRead * pWav->channels * sizeof(drwav_int16) > DRWAV_SIZE_MAX) {
+ framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int16) / pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) {
+ return drwav_read_pcm_frames_s16__pcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) {
+ return drwav_read_pcm_frames_s16__ieee(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) {
+ return drwav_read_pcm_frames_s16__alaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) {
+ return drwav_read_pcm_frames_s16__mulaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ return drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pBufferOut);
+ }
+
+ return 0;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) {
+ drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) {
+ drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+
+DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ int x = pIn[i];
+ r = x << 8;
+ r = r - 32768;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ int x = ((int)(((unsigned int)(((const drwav_uint8*)pIn)[i*3+0]) << 8) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+1]) << 16) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+2])) << 24)) >> 8;
+ r = x >> 8;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ int x = pIn[i];
+ r = x >> 16;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ float x = pIn[i];
+ float c;
+ c = ((x < -1) ? -1 : ((x > 1) ? 1 : x));
+ c = c + 1;
+ r = (int)(c * 32767.5f);
+ r = r - 32768;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ double x = pIn[i];
+ double c;
+ c = ((x < -1) ? -1 : ((x > 1) ? 1 : x));
+ c = c + 1;
+ r = (int)(c * 32767.5);
+ r = r - 32768;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ pOut[i] = drwav__alaw_to_s16(pIn[i]);
+ }
+}
+
+DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ pOut[i] = drwav__mulaw_to_s16(pIn[i]);
+ }
+}
+
+
+
+static void drwav__pcm_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample)
+{
+ unsigned int i;
+
+ /* Special case for 8-bit sample data because it's treated as unsigned. */
+ if (bytesPerSample == 1) {
+ drwav_u8_to_f32(pOut, pIn, sampleCount);
+ return;
+ }
+
+ /* Slightly more optimal implementation for common formats. */
+ if (bytesPerSample == 2) {
+ drwav_s16_to_f32(pOut, (const drwav_int16*)pIn, sampleCount);
+ return;
+ }
+ if (bytesPerSample == 3) {
+ drwav_s24_to_f32(pOut, pIn, sampleCount);
+ return;
+ }
+ if (bytesPerSample == 4) {
+ drwav_s32_to_f32(pOut, (const drwav_int32*)pIn, sampleCount);
+ return;
+ }
+
+
+ /* Anything more than 64 bits per sample is not supported. */
+ if (bytesPerSample > 8) {
+ DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut));
+ return;
+ }
+
+
+ /* Generic, slow converter. */
+ for (i = 0; i < sampleCount; ++i) {
+ drwav_uint64 sample = 0;
+ unsigned int shift = (8 - bytesPerSample) * 8;
+
+ unsigned int j;
+ for (j = 0; j < bytesPerSample; j += 1) {
+ DRWAV_ASSERT(j < 8);
+ sample |= (drwav_uint64)(pIn[j]) << shift;
+ shift += 8;
+ }
+
+ pIn += j;
+ *pOut++ = (float)((drwav_int64)sample / 9223372036854775807.0);
+ }
+}
+
+static void drwav__ieee_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample)
+{
+ if (bytesPerSample == 4) {
+ unsigned int i;
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = ((const float*)pIn)[i];
+ }
+ return;
+ } else if (bytesPerSample == 8) {
+ drwav_f64_to_f32(pOut, (const double*)pIn, sampleCount);
+ return;
+ } else {
+ /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */
+ DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut));
+ return;
+ }
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_f32__pcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__pcm_to_f32(pBufferOut, sampleData, (size_t)framesRead*pWav->channels, bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__ima(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ /* Fast path. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) {
+ return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__ieee_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__alaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_alaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__mulaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_mulaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ if (framesToRead * pWav->channels * sizeof(float) > DRWAV_SIZE_MAX) {
+ framesToRead = DRWAV_SIZE_MAX / sizeof(float) / pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) {
+ return drwav_read_pcm_frames_f32__pcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ return drwav_read_pcm_frames_f32__msadpcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) {
+ return drwav_read_pcm_frames_f32__ieee(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) {
+ return drwav_read_pcm_frames_f32__alaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) {
+ return drwav_read_pcm_frames_f32__mulaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return drwav_read_pcm_frames_f32__ima(pWav, framesToRead, pBufferOut);
+ }
+
+ return 0;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) {
+ drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) {
+ drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+
+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+#ifdef DR_WAV_LIBSNDFILE_COMPAT
+ /*
+ It appears libsndfile uses slightly different logic for the u8 -> f32 conversion to dr_wav, which in my opinion is incorrect. It appears
+ libsndfile performs the conversion something like "f32 = (u8 / 256) * 2 - 1", however I think it should be "f32 = (u8 / 255) * 2 - 1" (note
+ the divisor of 256 vs 255). I use libsndfile as a benchmark for testing, so I'm therefore leaving this block here just for my automated
+ correctness testing. This is disabled by default.
+ */
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (pIn[i] / 256.0f) * 2 - 1;
+ }
+#else
+ for (i = 0; i < sampleCount; ++i) {
+ float x = pIn[i];
+ x = x * 0.00784313725490196078f; /* 0..255 to 0..2 */
+ x = x - 1; /* 0..2 to -1..1 */
+
+ *pOut++ = x;
+ }
+#endif
+}
+
+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = pIn[i] * 0.000030517578125f;
+ }
+}
+
+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ double x;
+ drwav_uint32 a = ((drwav_uint32)(pIn[i*3+0]) << 8);
+ drwav_uint32 b = ((drwav_uint32)(pIn[i*3+1]) << 16);
+ drwav_uint32 c = ((drwav_uint32)(pIn[i*3+2]) << 24);
+
+ x = (double)((drwav_int32)(a | b | c) >> 8);
+ *pOut++ = (float)(x * 0.00000011920928955078125);
+ }
+}
+
+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount)
+{
+ size_t i;
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (float)(pIn[i] / 2147483648.0);
+ }
+}
+
+DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (float)pIn[i];
+ }
+}
+
+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = drwav__alaw_to_s16(pIn[i]) / 32768.0f;
+ }
+}
+
+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = drwav__mulaw_to_s16(pIn[i]) / 32768.0f;
+ }
+}
+
+
+
+static void drwav__pcm_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ unsigned int i;
+
+ /* Special case for 8-bit sample data because it's treated as unsigned. */
+ if (bytesPerSample == 1) {
+ drwav_u8_to_s32(pOut, pIn, totalSampleCount);
+ return;
+ }
+
+ /* Slightly more optimal implementation for common formats. */
+ if (bytesPerSample == 2) {
+ drwav_s16_to_s32(pOut, (const drwav_int16*)pIn, totalSampleCount);
+ return;
+ }
+ if (bytesPerSample == 3) {
+ drwav_s24_to_s32(pOut, pIn, totalSampleCount);
+ return;
+ }
+ if (bytesPerSample == 4) {
+ for (i = 0; i < totalSampleCount; ++i) {
+ *pOut++ = ((const drwav_int32*)pIn)[i];
+ }
+ return;
+ }
+
+
+ /* Anything more than 64 bits per sample is not supported. */
+ if (bytesPerSample > 8) {
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+
+
+ /* Generic, slow converter. */
+ for (i = 0; i < totalSampleCount; ++i) {
+ drwav_uint64 sample = 0;
+ unsigned int shift = (8 - bytesPerSample) * 8;
+
+ unsigned int j;
+ for (j = 0; j < bytesPerSample; j += 1) {
+ DRWAV_ASSERT(j < 8);
+ sample |= (drwav_uint64)(pIn[j]) << shift;
+ shift += 8;
+ }
+
+ pIn += j;
+ *pOut++ = (drwav_int32)((drwav_int64)sample >> 32);
+ }
+}
+
+static void drwav__ieee_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ if (bytesPerSample == 4) {
+ drwav_f32_to_s32(pOut, (const float*)pIn, totalSampleCount);
+ return;
+ } else if (bytesPerSample == 8) {
+ drwav_f64_to_s32(pOut, (const double*)pIn, totalSampleCount);
+ return;
+ } else {
+ /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ /* Fast path. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) {
+ return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__pcm_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__ieee_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_alaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_mulaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ if (framesToRead * pWav->channels * sizeof(drwav_int32) > DRWAV_SIZE_MAX) {
+ framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int32) / pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) {
+ return drwav_read_pcm_frames_s32__pcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ return drwav_read_pcm_frames_s32__msadpcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) {
+ return drwav_read_pcm_frames_s32__ieee(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) {
+ return drwav_read_pcm_frames_s32__alaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) {
+ return drwav_read_pcm_frames_s32__mulaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return drwav_read_pcm_frames_s32__ima(pWav, framesToRead, pBufferOut);
+ }
+
+ return 0;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) {
+ drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) {
+ drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+
+DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = ((int)pIn[i] - 128) << 24;
+ }
+}
+
+DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = pIn[i] << 16;
+ }
+}
+
+DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ unsigned int s0 = pIn[i*3 + 0];
+ unsigned int s1 = pIn[i*3 + 1];
+ unsigned int s2 = pIn[i*3 + 2];
+
+ drwav_int32 sample32 = (drwav_int32)((s0 << 8) | (s1 << 16) | (s2 << 24));
+ *pOut++ = sample32;
+ }
+}
+
+DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]);
+ }
+}
+
+DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]);
+ }
+}
+
+DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = ((drwav_int32)drwav__alaw_to_s16(pIn[i])) << 16;
+ }
+}
+
+DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i= 0; i < sampleCount; ++i) {
+ *pOut++ = ((drwav_int32)drwav__mulaw_to_s16(pIn[i])) << 16;
+ }
+}
+
+
+
+static drwav_int16* drwav__read_pcm_frames_and_close_s16(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount)
+{
+ drwav_uint64 sampleDataSize;
+ drwav_int16* pSampleData;
+ drwav_uint64 framesRead;
+
+ DRWAV_ASSERT(pWav != NULL);
+
+ sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int16);
+ if (sampleDataSize > DRWAV_SIZE_MAX) {
+ drwav_uninit(pWav);
+ return NULL; /* File's too big. */
+ }
+
+ pSampleData = (drwav_int16*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */
+ if (pSampleData == NULL) {
+ drwav_uninit(pWav);
+ return NULL; /* Failed to allocate memory. */
+ }
+
+ framesRead = drwav_read_pcm_frames_s16(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData);
+ if (framesRead != pWav->totalPCMFrameCount) {
+ drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks);
+ drwav_uninit(pWav);
+ return NULL; /* There was an error reading the samples. */
+ }
+
+ drwav_uninit(pWav);
+
+ if (sampleRate) {
+ *sampleRate = pWav->sampleRate;
+ }
+ if (channels) {
+ *channels = pWav->channels;
+ }
+ if (totalFrameCount) {
+ *totalFrameCount = pWav->totalPCMFrameCount;
+ }
+
+ return pSampleData;
+}
+
+static float* drwav__read_pcm_frames_and_close_f32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount)
+{
+ drwav_uint64 sampleDataSize;
+ float* pSampleData;
+ drwav_uint64 framesRead;
+
+ DRWAV_ASSERT(pWav != NULL);
+
+ sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float);
+ if (sampleDataSize > DRWAV_SIZE_MAX) {
+ drwav_uninit(pWav);
+ return NULL; /* File's too big. */
+ }
+
+ pSampleData = (float*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */
+ if (pSampleData == NULL) {
+ drwav_uninit(pWav);
+ return NULL; /* Failed to allocate memory. */
+ }
+
+ framesRead = drwav_read_pcm_frames_f32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData);
+ if (framesRead != pWav->totalPCMFrameCount) {
+ drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks);
+ drwav_uninit(pWav);
+ return NULL; /* There was an error reading the samples. */
+ }
+
+ drwav_uninit(pWav);
+
+ if (sampleRate) {
+ *sampleRate = pWav->sampleRate;
+ }
+ if (channels) {
+ *channels = pWav->channels;
+ }
+ if (totalFrameCount) {
+ *totalFrameCount = pWav->totalPCMFrameCount;
+ }
+
+ return pSampleData;
+}
+
+static drwav_int32* drwav__read_pcm_frames_and_close_s32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount)
+{
+ drwav_uint64 sampleDataSize;
+ drwav_int32* pSampleData;
+ drwav_uint64 framesRead;
+
+ DRWAV_ASSERT(pWav != NULL);
+
+ sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int32);
+ if (sampleDataSize > DRWAV_SIZE_MAX) {
+ drwav_uninit(pWav);
+ return NULL; /* File's too big. */
+ }
+
+ pSampleData = (drwav_int32*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */
+ if (pSampleData == NULL) {
+ drwav_uninit(pWav);
+ return NULL; /* Failed to allocate memory. */
+ }
+
+ framesRead = drwav_read_pcm_frames_s32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData);
+ if (framesRead != pWav->totalPCMFrameCount) {
+ drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks);
+ drwav_uninit(pWav);
+ return NULL; /* There was an error reading the samples. */
+ }
+
+ drwav_uninit(pWav);
+
+ if (sampleRate) {
+ *sampleRate = pWav->sampleRate;
+ }
+ if (channels) {
+ *channels = pWav->channels;
+ }
+ if (totalFrameCount) {
+ *totalFrameCount = pWav->totalPCMFrameCount;
+ }
+
+ return pSampleData;
+}
+
+
+
+DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+#ifndef DR_WAV_NO_STDIO
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+#endif
+
+DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+#endif /* DR_WAV_NO_CONVERSION_API */
+
+
+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks != NULL) {
+ drwav__free_from_callbacks(p, pAllocationCallbacks);
+ } else {
+ drwav__free_default(p, NULL);
+ }
+}
+
+DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data)
+{
+ return drwav__bytes_to_u16(data);
+}
+
+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data)
+{
+ return drwav__bytes_to_s16(data);
+}
+
+DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data)
+{
+ return drwav__bytes_to_u32(data);
+}
+
+DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data)
+{
+ return drwav__bytes_to_s32(data);
+}
+
+DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data)
+{
+ return drwav__bytes_to_u64(data);
+}
+
+DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data)
+{
+ return drwav__bytes_to_s64(data);
+}
+
+
+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16])
+{
+ return drwav__guid_equal(a, b);
+}
+
+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b)
+{
+ return drwav__fourcc_equal(a, b);
+}
+
+#endif /* dr_wav_c */
+#endif /* DR_WAV_IMPLEMENTATION */
+
+/*
+RELEASE NOTES - v0.11.0
+=======================
+Version 0.11.0 has breaking API changes.
+
+Improved Client-Defined Memory Allocation
+-----------------------------------------
+The main change with this release is the addition of a more flexible way of implementing custom memory allocation routines. The
+existing system of DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE are still in place and will be used by default when no custom
+allocation callbacks are specified.
+
+To use the new system, you pass in a pointer to a drwav_allocation_callbacks object to drwav_init() and family, like this:
+
+ void* my_malloc(size_t sz, void* pUserData)
+ {
+ return malloc(sz);
+ }
+ void* my_realloc(void* p, size_t sz, void* pUserData)
+ {
+ return realloc(p, sz);
+ }
+ void my_free(void* p, void* pUserData)
+ {
+ free(p);
+ }
+
+ ...
+
+ drwav_allocation_callbacks allocationCallbacks;
+ allocationCallbacks.pUserData = &myData;
+ allocationCallbacks.onMalloc = my_malloc;
+ allocationCallbacks.onRealloc = my_realloc;
+ allocationCallbacks.onFree = my_free;
+ drwav_init_file(&wav, "my_file.wav", &allocationCallbacks);
+
+The advantage of this new system is that it allows you to specify user data which will be passed in to the allocation routines.
+
+Passing in null for the allocation callbacks object will cause dr_wav to use defaults which is the same as DRWAV_MALLOC,
+DRWAV_REALLOC and DRWAV_FREE and the equivalent of how it worked in previous versions.
+
+Every API that opens a drwav object now takes this extra parameter. These include the following:
+
+ drwav_init()
+ drwav_init_ex()
+ drwav_init_file()
+ drwav_init_file_ex()
+ drwav_init_file_w()
+ drwav_init_file_w_ex()
+ drwav_init_memory()
+ drwav_init_memory_ex()
+ drwav_init_write()
+ drwav_init_write_sequential()
+ drwav_init_write_sequential_pcm_frames()
+ drwav_init_file_write()
+ drwav_init_file_write_sequential()
+ drwav_init_file_write_sequential_pcm_frames()
+ drwav_init_file_write_w()
+ drwav_init_file_write_sequential_w()
+ drwav_init_file_write_sequential_pcm_frames_w()
+ drwav_init_memory_write()
+ drwav_init_memory_write_sequential()
+ drwav_init_memory_write_sequential_pcm_frames()
+ drwav_open_and_read_pcm_frames_s16()
+ drwav_open_and_read_pcm_frames_f32()
+ drwav_open_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_pcm_frames_s16()
+ drwav_open_file_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_pcm_frames_s16_w()
+ drwav_open_file_and_read_pcm_frames_f32_w()
+ drwav_open_file_and_read_pcm_frames_s32_w()
+ drwav_open_memory_and_read_pcm_frames_s16()
+ drwav_open_memory_and_read_pcm_frames_f32()
+ drwav_open_memory_and_read_pcm_frames_s32()
+
+Endian Improvements
+-------------------
+Previously, the following APIs returned little-endian audio data. These now return native-endian data. This improves compatibility
+on big-endian architectures.
+
+ drwav_read_pcm_frames()
+ drwav_read_pcm_frames_s16()
+ drwav_read_pcm_frames_s32()
+ drwav_read_pcm_frames_f32()
+ drwav_open_and_read_pcm_frames_s16()
+ drwav_open_and_read_pcm_frames_s32()
+ drwav_open_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_pcm_frames_s16()
+ drwav_open_file_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_pcm_frames_s16_w()
+ drwav_open_file_and_read_pcm_frames_s32_w()
+ drwav_open_file_and_read_pcm_frames_f32_w()
+ drwav_open_memory_and_read_pcm_frames_s16()
+ drwav_open_memory_and_read_pcm_frames_s32()
+ drwav_open_memory_and_read_pcm_frames_f32()
+
+APIs have been added to give you explicit control over whether or not audio data is read or written in big- or little-endian byte
+order:
+
+ drwav_read_pcm_frames_le()
+ drwav_read_pcm_frames_be()
+ drwav_read_pcm_frames_s16le()
+ drwav_read_pcm_frames_s16be()
+ drwav_read_pcm_frames_f32le()
+ drwav_read_pcm_frames_f32be()
+ drwav_read_pcm_frames_s32le()
+ drwav_read_pcm_frames_s32be()
+ drwav_write_pcm_frames_le()
+ drwav_write_pcm_frames_be()
+
+Removed APIs
+------------
+The following APIs were deprecated in version 0.10.0 and have now been removed:
+
+ drwav_open()
+ drwav_open_ex()
+ drwav_open_write()
+ drwav_open_write_sequential()
+ drwav_open_file()
+ drwav_open_file_ex()
+ drwav_open_file_write()
+ drwav_open_file_write_sequential()
+ drwav_open_memory()
+ drwav_open_memory_ex()
+ drwav_open_memory_write()
+ drwav_open_memory_write_sequential()
+ drwav_close()
+
+
+
+RELEASE NOTES - v0.10.0
+=======================
+Version 0.10.0 has breaking API changes. There are no significant bug fixes in this release, so if you are affected you do
+not need to upgrade.
+
+Removed APIs
+------------
+The following APIs were deprecated in version 0.9.0 and have been completely removed in version 0.10.0:
+
+ drwav_read()
+ drwav_read_s16()
+ drwav_read_f32()
+ drwav_read_s32()
+ drwav_seek_to_sample()
+ drwav_write()
+ drwav_open_and_read_s16()
+ drwav_open_and_read_f32()
+ drwav_open_and_read_s32()
+ drwav_open_file_and_read_s16()
+ drwav_open_file_and_read_f32()
+ drwav_open_file_and_read_s32()
+ drwav_open_memory_and_read_s16()
+ drwav_open_memory_and_read_f32()
+ drwav_open_memory_and_read_s32()
+ drwav::totalSampleCount
+
+See release notes for version 0.9.0 at the bottom of this file for replacement APIs.
+
+Deprecated APIs
+---------------
+The following APIs have been deprecated. There is a confusing and completely arbitrary difference between drwav_init*() and
+drwav_open*(), where drwav_init*() initializes a pre-allocated drwav object, whereas drwav_open*() will first allocated a
+drwav object on the heap and then initialize it. drwav_open*() has been deprecated which means you must now use a pre-
+allocated drwav object with drwav_init*(). If you need the previous functionality, you can just do a malloc() followed by
+a called to one of the drwav_init*() APIs.
+
+ drwav_open()
+ drwav_open_ex()
+ drwav_open_write()
+ drwav_open_write_sequential()
+ drwav_open_file()
+ drwav_open_file_ex()
+ drwav_open_file_write()
+ drwav_open_file_write_sequential()
+ drwav_open_memory()
+ drwav_open_memory_ex()
+ drwav_open_memory_write()
+ drwav_open_memory_write_sequential()
+ drwav_close()
+
+These APIs will be removed completely in a future version. The rationale for this change is to remove confusion between the
+two different ways to initialize a drwav object.
+*/
+
+/*
+REVISION HISTORY
+================
+v0.12.16 - 2020-12-02
+ - Fix a bug when trying to read more bytes than can fit in a size_t.
+
+v0.12.15 - 2020-11-21
+ - Fix compilation with OpenWatcom.
+
+v0.12.14 - 2020-11-13
+ - Minor code clean up.
+
+v0.12.13 - 2020-11-01
+ - Improve compiler support for older versions of GCC.
+
+v0.12.12 - 2020-09-28
+ - Add support for RF64.
+ - Fix a bug in writing mode where the size of the RIFF chunk incorrectly includes the header section.
+
+v0.12.11 - 2020-09-08
+ - Fix a compilation error on older compilers.
+
+v0.12.10 - 2020-08-24
+ - Fix a bug when seeking with ADPCM formats.
+
+v0.12.9 - 2020-08-02
+ - Simplify sized types.
+
+v0.12.8 - 2020-07-25
+ - Fix a compilation warning.
+
+v0.12.7 - 2020-07-15
+ - Fix some bugs on big-endian architectures.
+ - Fix an error in s24 to f32 conversion.
+
+v0.12.6 - 2020-06-23
+ - Change drwav_read_*() to allow NULL to be passed in as the output buffer which is equivalent to a forward seek.
+ - Fix a buffer overflow when trying to decode invalid IMA-ADPCM files.
+ - Add include guard for the implementation section.
+
+v0.12.5 - 2020-05-27
+ - Minor documentation fix.
+
+v0.12.4 - 2020-05-16
+ - Replace assert() with DRWAV_ASSERT().
+ - Add compile-time and run-time version querying.
+ - DRWAV_VERSION_MINOR
+ - DRWAV_VERSION_MAJOR
+ - DRWAV_VERSION_REVISION
+ - DRWAV_VERSION_STRING
+ - drwav_version()
+ - drwav_version_string()
+
+v0.12.3 - 2020-04-30
+ - Fix compilation errors with VC6.
+
+v0.12.2 - 2020-04-21
+ - Fix a bug where drwav_init_file() does not close the file handle after attempting to load an erroneous file.
+
+v0.12.1 - 2020-04-13
+ - Fix some pedantic warnings.
+
+v0.12.0 - 2020-04-04
+ - API CHANGE: Add container and format parameters to the chunk callback.
+ - Minor documentation updates.
+
+v0.11.5 - 2020-03-07
+ - Fix compilation error with Visual Studio .NET 2003.
+
+v0.11.4 - 2020-01-29
+ - Fix some static analysis warnings.
+ - Fix a bug when reading f32 samples from an A-law encoded stream.
+
+v0.11.3 - 2020-01-12
+ - Minor changes to some f32 format conversion routines.
+ - Minor bug fix for ADPCM conversion when end of file is reached.
+
+v0.11.2 - 2019-12-02
+ - Fix a possible crash when using custom memory allocators without a custom realloc() implementation.
+ - Fix an integer overflow bug.
+ - Fix a null pointer dereference bug.
+ - Add limits to sample rate, channels and bits per sample to tighten up some validation.
+
+v0.11.1 - 2019-10-07
+ - Internal code clean up.
+
+v0.11.0 - 2019-10-06
+ - API CHANGE: Add support for user defined memory allocation routines. This system allows the program to specify their own memory allocation
+ routines with a user data pointer for client-specific contextual data. This adds an extra parameter to the end of the following APIs:
+ - drwav_init()
+ - drwav_init_ex()
+ - drwav_init_file()
+ - drwav_init_file_ex()
+ - drwav_init_file_w()
+ - drwav_init_file_w_ex()
+ - drwav_init_memory()
+ - drwav_init_memory_ex()
+ - drwav_init_write()
+ - drwav_init_write_sequential()
+ - drwav_init_write_sequential_pcm_frames()
+ - drwav_init_file_write()
+ - drwav_init_file_write_sequential()
+ - drwav_init_file_write_sequential_pcm_frames()
+ - drwav_init_file_write_w()
+ - drwav_init_file_write_sequential_w()
+ - drwav_init_file_write_sequential_pcm_frames_w()
+ - drwav_init_memory_write()
+ - drwav_init_memory_write_sequential()
+ - drwav_init_memory_write_sequential_pcm_frames()
+ - drwav_open_and_read_pcm_frames_s16()
+ - drwav_open_and_read_pcm_frames_f32()
+ - drwav_open_and_read_pcm_frames_s32()
+ - drwav_open_file_and_read_pcm_frames_s16()
+ - drwav_open_file_and_read_pcm_frames_f32()
+ - drwav_open_file_and_read_pcm_frames_s32()
+ - drwav_open_file_and_read_pcm_frames_s16_w()
+ - drwav_open_file_and_read_pcm_frames_f32_w()
+ - drwav_open_file_and_read_pcm_frames_s32_w()
+ - drwav_open_memory_and_read_pcm_frames_s16()
+ - drwav_open_memory_and_read_pcm_frames_f32()
+ - drwav_open_memory_and_read_pcm_frames_s32()
+ Set this extra parameter to NULL to use defaults which is the same as the previous behaviour. Setting this NULL will use
+ DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE.
+ - Add support for reading and writing PCM frames in an explicit endianness. New APIs:
+ - drwav_read_pcm_frames_le()
+ - drwav_read_pcm_frames_be()
+ - drwav_read_pcm_frames_s16le()
+ - drwav_read_pcm_frames_s16be()
+ - drwav_read_pcm_frames_f32le()
+ - drwav_read_pcm_frames_f32be()
+ - drwav_read_pcm_frames_s32le()
+ - drwav_read_pcm_frames_s32be()
+ - drwav_write_pcm_frames_le()
+ - drwav_write_pcm_frames_be()
+ - Remove deprecated APIs.
+ - API CHANGE: The following APIs now return native-endian data. Previously they returned little-endian data.
+ - drwav_read_pcm_frames()
+ - drwav_read_pcm_frames_s16()
+ - drwav_read_pcm_frames_s32()
+ - drwav_read_pcm_frames_f32()
+ - drwav_open_and_read_pcm_frames_s16()
+ - drwav_open_and_read_pcm_frames_s32()
+ - drwav_open_and_read_pcm_frames_f32()
+ - drwav_open_file_and_read_pcm_frames_s16()
+ - drwav_open_file_and_read_pcm_frames_s32()
+ - drwav_open_file_and_read_pcm_frames_f32()
+ - drwav_open_file_and_read_pcm_frames_s16_w()
+ - drwav_open_file_and_read_pcm_frames_s32_w()
+ - drwav_open_file_and_read_pcm_frames_f32_w()
+ - drwav_open_memory_and_read_pcm_frames_s16()
+ - drwav_open_memory_and_read_pcm_frames_s32()
+ - drwav_open_memory_and_read_pcm_frames_f32()
+
+v0.10.1 - 2019-08-31
+ - Correctly handle partial trailing ADPCM blocks.
+
+v0.10.0 - 2019-08-04
+ - Remove deprecated APIs.
+ - Add wchar_t variants for file loading APIs:
+ drwav_init_file_w()
+ drwav_init_file_ex_w()
+ drwav_init_file_write_w()
+ drwav_init_file_write_sequential_w()
+ - Add drwav_target_write_size_bytes() which calculates the total size in bytes of a WAV file given a format and sample count.
+ - Add APIs for specifying the PCM frame count instead of the sample count when opening in sequential write mode:
+ drwav_init_write_sequential_pcm_frames()
+ drwav_init_file_write_sequential_pcm_frames()
+ drwav_init_file_write_sequential_pcm_frames_w()
+ drwav_init_memory_write_sequential_pcm_frames()
+ - Deprecate drwav_open*() and drwav_close():
+ drwav_open()
+ drwav_open_ex()
+ drwav_open_write()
+ drwav_open_write_sequential()
+ drwav_open_file()
+ drwav_open_file_ex()
+ drwav_open_file_write()
+ drwav_open_file_write_sequential()
+ drwav_open_memory()
+ drwav_open_memory_ex()
+ drwav_open_memory_write()
+ drwav_open_memory_write_sequential()
+ drwav_close()
+ - Minor documentation updates.
+
+v0.9.2 - 2019-05-21
+ - Fix warnings.
+
+v0.9.1 - 2019-05-05
+ - Add support for C89.
+ - Change license to choice of public domain or MIT-0.
+
+v0.9.0 - 2018-12-16
+ - API CHANGE: Add new reading APIs for reading by PCM frames instead of samples. Old APIs have been deprecated and
+ will be removed in v0.10.0. Deprecated APIs and their replacements:
+ drwav_read() -> drwav_read_pcm_frames()
+ drwav_read_s16() -> drwav_read_pcm_frames_s16()
+ drwav_read_f32() -> drwav_read_pcm_frames_f32()
+ drwav_read_s32() -> drwav_read_pcm_frames_s32()
+ drwav_seek_to_sample() -> drwav_seek_to_pcm_frame()
+ drwav_write() -> drwav_write_pcm_frames()
+ drwav_open_and_read_s16() -> drwav_open_and_read_pcm_frames_s16()
+ drwav_open_and_read_f32() -> drwav_open_and_read_pcm_frames_f32()
+ drwav_open_and_read_s32() -> drwav_open_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_s16() -> drwav_open_file_and_read_pcm_frames_s16()
+ drwav_open_file_and_read_f32() -> drwav_open_file_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_s32() -> drwav_open_file_and_read_pcm_frames_s32()
+ drwav_open_memory_and_read_s16() -> drwav_open_memory_and_read_pcm_frames_s16()
+ drwav_open_memory_and_read_f32() -> drwav_open_memory_and_read_pcm_frames_f32()
+ drwav_open_memory_and_read_s32() -> drwav_open_memory_and_read_pcm_frames_s32()
+ drwav::totalSampleCount -> drwav::totalPCMFrameCount
+ - API CHANGE: Rename drwav_open_and_read_file_*() to drwav_open_file_and_read_*().
+ - API CHANGE: Rename drwav_open_and_read_memory_*() to drwav_open_memory_and_read_*().
+ - Add built-in support for smpl chunks.
+ - Add support for firing a callback for each chunk in the file at initialization time.
+ - This is enabled through the drwav_init_ex(), etc. family of APIs.
+ - Handle invalid FMT chunks more robustly.
+
+v0.8.5 - 2018-09-11
+ - Const correctness.
+ - Fix a potential stack overflow.
+
+v0.8.4 - 2018-08-07
+ - Improve 64-bit detection.
+
+v0.8.3 - 2018-08-05
+ - Fix C++ build on older versions of GCC.
+
+v0.8.2 - 2018-08-02
+ - Fix some big-endian bugs.
+
+v0.8.1 - 2018-06-29
+ - Add support for sequential writing APIs.
+ - Disable seeking in write mode.
+ - Fix bugs with Wave64.
+ - Fix typos.
+
+v0.8 - 2018-04-27
+ - Bug fix.
+ - Start using major.minor.revision versioning.
+
+v0.7f - 2018-02-05
+ - Restrict ADPCM formats to a maximum of 2 channels.
+
+v0.7e - 2018-02-02
+ - Fix a crash.
+
+v0.7d - 2018-02-01
+ - Fix a crash.
+
+v0.7c - 2018-02-01
+ - Set drwav.bytesPerSample to 0 for all compressed formats.
+ - Fix a crash when reading 16-bit floating point WAV files. In this case dr_wav will output silence for
+ all format conversion reading APIs (*_s16, *_s32, *_f32 APIs).
+ - Fix some divide-by-zero errors.
+
+v0.7b - 2018-01-22
+ - Fix errors with seeking of compressed formats.
+ - Fix compilation error when DR_WAV_NO_CONVERSION_API
+
+v0.7a - 2017-11-17
+ - Fix some GCC warnings.
+
+v0.7 - 2017-11-04
+ - Add writing APIs.
+
+v0.6 - 2017-08-16
+ - API CHANGE: Rename dr_* types to drwav_*.
+ - Add support for custom implementations of malloc(), realloc(), etc.
+ - Add support for Microsoft ADPCM.
+ - Add support for IMA ADPCM (DVI, format code 0x11).
+ - Optimizations to drwav_read_s16().
+ - Bug fixes.
+
+v0.5g - 2017-07-16
+ - Change underlying type for booleans to unsigned.
+
+v0.5f - 2017-04-04
+ - Fix a minor bug with drwav_open_and_read_s16() and family.
+
+v0.5e - 2016-12-29
+ - Added support for reading samples as signed 16-bit integers. Use the _s16() family of APIs for this.
+ - Minor fixes to documentation.
+
+v0.5d - 2016-12-28
+ - Use drwav_int* and drwav_uint* sized types to improve compiler support.
+
+v0.5c - 2016-11-11
+ - Properly handle JUNK chunks that come before the FMT chunk.
+
+v0.5b - 2016-10-23
+ - A minor change to drwav_bool8 and drwav_bool32 types.
+
+v0.5a - 2016-10-11
+ - Fixed a bug with drwav_open_and_read() and family due to incorrect argument ordering.
+ - Improve A-law and mu-law efficiency.
+
+v0.5 - 2016-09-29
+ - API CHANGE. Swap the order of "channels" and "sampleRate" parameters in drwav_open_and_read*(). Rationale for this is to
+ keep it consistent with dr_audio and dr_flac.
+
+v0.4b - 2016-09-18
+ - Fixed a typo in documentation.
+
+v0.4a - 2016-09-18
+ - Fixed a typo.
+ - Change date format to ISO 8601 (YYYY-MM-DD)
+
+v0.4 - 2016-07-13
+ - API CHANGE. Make onSeek consistent with dr_flac.
+ - API CHANGE. Rename drwav_seek() to drwav_seek_to_sample() for clarity and consistency with dr_flac.
+ - Added support for Sony Wave64.
+
+v0.3a - 2016-05-28
+ - API CHANGE. Return drwav_bool32 instead of int in onSeek callback.
+ - Fixed a memory leak.
+
+v0.3 - 2016-05-22
+ - Lots of API changes for consistency.
+
+v0.2a - 2016-05-16
+ - Fixed Linux/GCC build.
+
+v0.2 - 2016-05-11
+ - Added support for reading data as signed 32-bit PCM for consistency with dr_flac.
+
+v0.1a - 2016-05-07
+ - Fixed a bug in drwav_open_file() where the file handle would not be closed if the loader failed to initialize.
+
+v0.1 - 2016-05-04
+ - Initial versioned release.
+*/
+
+/*
+This software is available as a choice of the following licenses. Choose
+whichever you prefer.
+
+===============================================================================
+ALTERNATIVE 1 - Public Domain (www.unlicense.org)
+===============================================================================
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
+software, either in source code form or as a compiled binary, for any purpose,
+commercial or non-commercial, and by any means.
+
+In jurisdictions that recognize copyright laws, the author or authors of this
+software dedicate any and all copyright interest in the software to the public
+domain. We make this dedication for the benefit of the public at large and to
+the detriment of our heirs and successors. We intend this dedication to be an
+overt act of relinquishment in perpetuity of all present and future rights to
+this software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <http://unlicense.org/>
+
+===============================================================================
+ALTERNATIVE 2 - MIT No Attribution
+===============================================================================
+Copyright 2020 David Reid
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/
--- /dev/null
+#include "ggml.h"
+
+#include <assert.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdatomic.h>
+
+#include <pthread.h>
+
+#define GGML_DEBUG 0
+#define GGML_MEM_ALIGN 16
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+#define UNUSED(x) (void)(x)
+#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+
+#define GGML_ASSERT(x) assert(x)
+
+#ifdef GGML_USE_ACCELERATE
+#include <Accelerate/Accelerate.h>
+#endif
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+ return x;
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+ return x;
+}
+
+#else
+
+#include <immintrin.h>
+
+static inline float fp32_from_bits(uint32_t w) {
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } fp32 = { w };
+ return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32 = { f };
+ return fp32.as_bits;
+}
+
+float ggml_fp16_to_fp32(ggml_fp16_t h) {
+ const uint32_t w = (uint32_t) h << 16;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ const uint32_t two_w = w + w;
+
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+ const float exp_scale = 0x1.0p-112f;
+#else
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+ const uint32_t magic_mask = UINT32_C(126) << 23;
+ const float magic_bias = 0.5f;
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+ const uint32_t result = sign |
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+ return fp32_from_bits(result);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+ const float scale_to_inf = 0x1.0p+112f;
+ const float scale_to_zero = 0x1.0p-110f;
+#else
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+ const uint32_t w = fp32_to_bits(f);
+ const uint32_t shl1_w = w + w;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+ if (bias < UINT32_C(0x71000000)) {
+ bias = UINT32_C(0x71000000);
+ }
+
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+ const uint32_t bits = fp32_to_bits(base);
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+ const uint32_t nonsign = exp_bits + mantissa_bits;
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+#endif
+
+//
+// timing
+//
+
+int64_t ggml_time_ms(void) {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
+}
+
+int64_t ggml_time_us(void) {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
+}
+
+int64_t ggml_cycles(void) {
+ return clock();
+}
+
+int64_t ggml_cycles_per_ms(void) {
+ return CLOCKS_PER_SEC/1000;
+}
+
+#ifdef GGML_PERF
+#define ggml_perf_time_ms() ggml_time_ms()
+#define ggml_perf_time_us() ggml_time_us()
+#define ggml_perf_cycles() ggml_cycles()
+#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
+#else
+#define ggml_perf_time_ms() 0
+#define ggml_perf_time_us() 0
+#define ggml_perf_cycles() 0
+#define ggml_perf_cycles_per_ms() 0
+#endif
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+ const size_t CACHE_LINE_SIZE = hardware_destructive_interference_size;
+#else
+ const size_t CACHE_LINE_SIZE = 64;
+#endif
+
+const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+
+inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+ ggml_float sumf = 0.0;
+#ifdef __ARM_NEON
+ // NEON 128-bit
+ const int n16 = (n & ~15);
+
+ float32x4_t sum0 = vdupq_n_f32(0);
+ float32x4_t sum1 = vdupq_n_f32(0);
+ float32x4_t sum2 = vdupq_n_f32(0);
+ float32x4_t sum3 = vdupq_n_f32(0);
+
+ float32x4_t x0, x1, x2, x3;
+ float32x4_t y0, y1, y2, y3;
+
+ for (int i = 0; i < n16; i += 16) {
+ x0 = vld1q_f32(x + i + 0);
+ x1 = vld1q_f32(x + i + 4);
+ x2 = vld1q_f32(x + i + 8);
+ x3 = vld1q_f32(x + i + 12);
+
+ y0 = vld1q_f32(y + i + 0);
+ y1 = vld1q_f32(y + i + 4);
+ y2 = vld1q_f32(y + i + 8);
+ y3 = vld1q_f32(y + i + 12);
+
+ sum0 = vfmaq_f32(sum0, x0, y0);
+ sum1 = vfmaq_f32(sum1, x1, y1);
+ sum2 = vfmaq_f32(sum2, x2, y2);
+ sum3 = vfmaq_f32(sum3, x3, y3);
+ }
+
+ // reduce sum0..sum3 to sum0
+ sum0 = vaddq_f32(sum0, sum1);
+ sum2 = vaddq_f32(sum2, sum3);
+ sum0 = vaddq_f32(sum0, sum2);
+
+ float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
+ sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+
+ // leftovers
+ for (int i = n16; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#elif defined(__AVX2__)
+ // AVX 256-bit (unroll 4)
+ const int n32 = (n & ~31);
+
+ __m256 sum0 = _mm256_setzero_ps();
+ __m256 sum1 = _mm256_setzero_ps();
+ __m256 sum2 = _mm256_setzero_ps();
+ __m256 sum3 = _mm256_setzero_ps();
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = _mm256_loadu_ps(x + i + 0);
+ x1 = _mm256_loadu_ps(x + i + 8);
+ x2 = _mm256_loadu_ps(x + i + 16);
+ x3 = _mm256_loadu_ps(x + i + 24);
+
+ y0 = _mm256_loadu_ps(y + i + 0);
+ y1 = _mm256_loadu_ps(y + i + 8);
+ y2 = _mm256_loadu_ps(y + i + 16);
+ y3 = _mm256_loadu_ps(y + i + 24);
+
+ sum0 = _mm256_fmadd_ps(x0, y0, sum0);
+ sum1 = _mm256_fmadd_ps(x1, y1, sum1);
+ sum2 = _mm256_fmadd_ps(x2, y2, sum2);
+ sum3 = _mm256_fmadd_ps(x3, y3, sum3);
+ }
+
+ sum0 = _mm256_add_ps(sum0, sum1);
+ sum2 = _mm256_add_ps(sum2, sum3);
+ sum0 = _mm256_add_ps(sum0, sum2);
+
+ const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
+ const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+ const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+ sumf = _mm_cvtss_f32(r1);
+
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#endif
+
+ *s = sumf;
+}
+
+inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+ ggml_float sumf = 0.0;
+#ifdef __ARM_NEON
+ const int n32 = (n & ~31);
+
+ float16x8_t sum0 = vdupq_n_f16(0);
+ float16x8_t sum1 = vdupq_n_f16(0);
+ float16x8_t sum2 = vdupq_n_f16(0);
+ float16x8_t sum3 = vdupq_n_f16(0);
+
+ float16x8_t x0, x1, x2, x3;
+ float16x8_t y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = vld1q_f16(x + i + 0 );
+ x1 = vld1q_f16(x + i + 8 );
+ x2 = vld1q_f16(x + i + 16);
+ x3 = vld1q_f16(x + i + 24);
+
+ y0 = vld1q_f16(y + i + 0 );
+ y1 = vld1q_f16(y + i + 8 );
+ y2 = vld1q_f16(y + i + 16);
+ y3 = vld1q_f16(y + i + 24);
+
+ sum0 = vfmaq_f16(sum0, x0, y0);
+ sum1 = vfmaq_f16(sum1, x1, y1);
+ sum2 = vfmaq_f16(sum2, x2, y2);
+ sum3 = vfmaq_f16(sum3, x3, y3);
+ }
+
+ // reduce sum0..sum3 to sum0
+ sum0 = vaddq_f16(sum0, sum1);
+ sum2 = vaddq_f16(sum2, sum3);
+ sum0 = vaddq_f16(sum0, sum2);
+
+ // load sum0 into 2 float32x4_t
+ float32x4_t sum0f32 = vcvt_f32_f16(vget_low_f16(sum0));
+ float32x4_t sum1f32 = vcvt_f32_f16(vget_high_f16(sum0));
+
+ // reduce sum0f32 and sum1f32 to sumf
+ sum0f32 = vaddq_f32(sum0f32, sum1f32);
+
+ float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
+ sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ GGML_ASSERT(false); // should not end up here
+ sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+ }
+#elif defined(__AVX2__)
+ // AVX 256-bit (unroll 4)
+ const int n32 = (n & ~31);
+
+ __m256 sum0 = _mm256_setzero_ps();
+ __m256 sum1 = _mm256_setzero_ps();
+ __m256 sum2 = _mm256_setzero_ps();
+ __m256 sum3 = _mm256_setzero_ps();
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+ x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+ x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+ x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+ y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+ y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+ y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+ y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+ sum0 = _mm256_fmadd_ps(x0, y0, sum0);
+ sum1 = _mm256_fmadd_ps(x1, y1, sum1);
+ sum2 = _mm256_fmadd_ps(x2, y2, sum2);
+ sum3 = _mm256_fmadd_ps(x3, y3, sum3);
+ }
+
+ const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+ const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+ const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+ const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+ const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+ const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+ sumf = _mm_cvtss_f32(r1);
+
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ GGML_ASSERT(false);
+ sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+ }
+#else
+ for (int i = 0; i < n; ++i) {
+ sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+ }
+#endif
+
+ *s = sumf;
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#ifdef __ARM_NEON
+ // NEON 128-bit
+ const int n16 = (n & ~15);
+
+ const float32x4_t v4 = vdupq_n_f32(v);
+
+ float32x4_t x0, x1, x2, x3;
+ float32x4_t y0, y1, y2, y3;
+
+ for (int i = 0; i < n16; i += 16) {
+ x0 = vld1q_f32(x + i + 0);
+ x1 = vld1q_f32(x + i + 4);
+ x2 = vld1q_f32(x + i + 8);
+ x3 = vld1q_f32(x + i + 12);
+
+ y0 = vld1q_f32(y + i + 0);
+ y1 = vld1q_f32(y + i + 4);
+ y2 = vld1q_f32(y + i + 8);
+ y3 = vld1q_f32(y + i + 12);
+
+ y0 = vfmaq_f32(y0, x0, v4);
+ y1 = vfmaq_f32(y1, x1, v4);
+ y2 = vfmaq_f32(y2, x2, v4);
+ y3 = vfmaq_f32(y3, x3, v4);
+
+ vst1q_f32(y + i + 0, y0);
+ vst1q_f32(y + i + 4, y1);
+ vst1q_f32(y + i + 8, y2);
+ vst1q_f32(y + i + 12, y3);
+ }
+
+ // leftovers
+ for (int i = n16; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#elif defined(__AVX2__)
+ // AVX 256-bit (unroll 4)
+ const int n32 = (n & ~31);
+
+ const __m256 v4 = _mm256_set1_ps(v);
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ x0 = _mm256_loadu_ps(x + i + 0);
+ x1 = _mm256_loadu_ps(x + i + 8);
+ x2 = _mm256_loadu_ps(x + i + 16);
+ x3 = _mm256_loadu_ps(x + i + 24);
+
+ y0 = _mm256_loadu_ps(y + i + 0);
+ y1 = _mm256_loadu_ps(y + i + 8);
+ y2 = _mm256_loadu_ps(y + i + 16);
+ y3 = _mm256_loadu_ps(y + i + 24);
+
+ y0 = _mm256_fmadd_ps(x0, v4, y0);
+ y1 = _mm256_fmadd_ps(x1, v4, y1);
+ y2 = _mm256_fmadd_ps(x2, v4, y2);
+ y3 = _mm256_fmadd_ps(x3, v4, y3);
+
+ _mm256_storeu_ps(y + i + 0, y0);
+ _mm256_storeu_ps(y + i + 8, y1);
+ _mm256_storeu_ps(y + i + 16, y2);
+ _mm256_storeu_ps(y + i + 24, y3);
+ }
+
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#endif
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
+#ifdef __ARM_NEON
+ // NEON 128-bit
+ const int n32 = (n & ~31);
+
+ const float16x8_t v8 = vdupq_n_f16(v);
+
+ float16x8_t x0, x1, x2, x3;
+ float16x8_t y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ y0 = vld1q_f16(y + i + 0 );
+ y1 = vld1q_f16(y + i + 8 );
+ y2 = vld1q_f16(y + i + 16);
+ y3 = vld1q_f16(y + i + 24);
+
+ x0 = vld1q_f16(x + i + 0 );
+ x1 = vld1q_f16(x + i + 8 );
+ x2 = vld1q_f16(x + i + 16);
+ x3 = vld1q_f16(x + i + 24);
+
+ y0 = vfmaq_f16(y0, x0, v8);
+ y1 = vfmaq_f16(y1, x1, v8);
+ y2 = vfmaq_f16(y2, x2, v8);
+ y3 = vfmaq_f16(y3, x3, v8);
+
+ vst1q_f16(y + i + 0 , y0);
+ vst1q_f16(y + i + 8 , y1);
+ vst1q_f16(y + i + 16, y2);
+ vst1q_f16(y + i + 24, y3);
+ }
+
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ GGML_ASSERT(false);
+ y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+ }
+#elif defined(__AVX2__)
+ // AVX 256-bit
+ const int n32 = (n & ~31);
+
+ const __m256 v8 = _mm256_set1_ps(v);
+
+ __m256 x0, x1, x2, x3;
+ __m256 y0, y1, y2, y3;
+
+ for (int i = 0; i < n32; i += 32) {
+ y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+ y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+ y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+ y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+ x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+ x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+ x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+ x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+ y0 = _mm256_fmadd_ps(x0, v8, y0);
+ y1 = _mm256_fmadd_ps(x1, v8, y1);
+ y2 = _mm256_fmadd_ps(x2, v8, y2);
+ y3 = _mm256_fmadd_ps(x3, v8, y3);
+
+ _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+ _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+ _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+ _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+ }
+
+ // leftovers
+ for (int i = n32; i < n; ++i) {
+ GGML_ASSERT(false);
+ y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+ }
+#else
+ for (int i = 0; i < n; ++i) {
+ y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+ }
+#endif
+}
+
+inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
+inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+
+const ggml_float GELU_COEF_A = 0.044715;
+const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+
+inline static void ggml_vec_gelu_f32 (const int n, float * y, const float * x) {
+ for (int i = 0; i < n; ++i) {
+ //y[i] = 0.5f*x[i]*(1.f + tanhf(SQRT_2_OVER_PI*(x[i] + 0.044715f*x[i]*x[i]*x[i])));
+ //0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
+ const ggml_float xx = x[i];
+ y[i] = 0.5*xx*(1.0 + tanh(SQRT_2_OVER_PI*xx*(1.0 + GELU_COEF_A*xx*xx)));
+ }
+}
+
+inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+//
+// data types
+//
+
+const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+ sizeof(int8_t ),
+ sizeof(int16_t),
+ sizeof(int32_t),
+ sizeof(ggml_fp16_t),
+ sizeof(float ),
+};
+
+const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+ "NONE",
+
+ "DUP",
+ "ADD",
+ "SUB",
+ "MUL",
+ "DIV",
+ "SQR",
+ "SQRT",
+ "SUM",
+ "MEAN",
+ "REPEAT",
+ "ABS",
+ "SGN",
+ "NEG",
+ "STEP",
+ "RELU",
+ "GELU",
+ "NORM",
+
+ "MUL_MAT",
+
+ "SCALE",
+ "CPY",
+ "RESHAPE",
+ "VIEW",
+ "PERMUTE",
+ "TRANSPOSE",
+ "GET_ROWS",
+ "DIAG_MASK_INF",
+ "SOFT_MAX",
+ "ROPE",
+ "CONV_1D_1S",
+ "CONV_1D_2S",
+};
+
+const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+ "none",
+
+ "x",
+ "x+y",
+ "x-y",
+ "x*y",
+ "x/y",
+ "x^2",
+ "√x",
+ "Σx",
+ "Σx/n",
+ "repeat(x)",
+ "abs(x)",
+ "sgn(x)",
+ "-x",
+ "step(x)",
+ "relu(x)",
+ "gelu(x)",
+ "norm(x)",
+
+ "X*Y",
+
+ "x*v",
+ "x-\\>y",
+ "reshape(x)",
+ "view(x)",
+ "permute(x)",
+ "transpose(x)",
+ "get_rows(x)",
+ "diag_mask_inf(x)",
+ "soft_max(x)",
+ "rope(x)",
+ "conv_1d_1s(x)",
+ "conv_1d_2s(x)",
+};
+
+//
+// ggml object
+//
+
+struct ggml_object {
+ size_t offset;
+ size_t size;
+
+ struct ggml_object * next;
+
+ char padding[8];
+};
+
+const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
+static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
+
+//
+// ggml context
+//
+
+struct ggml_context {
+ size_t mem_size;
+ void * mem_buffer;
+ bool mem_buffer_owned;
+
+ int n_objects;
+
+ struct ggml_object * objects_begin;
+ struct ggml_object * objects_end;
+};
+
+struct ggml_context_container {
+ bool used;
+
+ struct ggml_context context;
+};
+
+//
+// compute types
+//
+
+enum ggml_task_type {
+ GGML_TASK_INIT = 0,
+ GGML_TASK_COMPUTE,
+ GGML_TASK_FINALIZE,
+};
+
+struct ggml_compute_params {
+ enum ggml_task_type type;
+
+ int ith, nth;
+
+ // work buffer for all threads
+ size_t wsize;
+ void * wdata;
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+ struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+};
+
+// global state
+struct ggml_state g_state;
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_print_object(const struct ggml_object * obj) {
+ GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
+ obj->offset, obj->size, (const void *) obj->next);
+}
+
+void ggml_print_objects(const struct ggml_context * ctx) {
+ struct ggml_object * obj = ctx->objects_begin;
+
+ GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
+
+ while (obj != NULL) {
+ ggml_print_object(obj);
+ obj = obj->next;
+ }
+
+ GGML_PRINT("%s: --- end ---\n", __func__);
+}
+
+int ggml_nelements(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+int ggml_nrows(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type];
+}
+
+size_t ggml_type_size(enum ggml_type type) {
+ return GGML_TYPE_SIZE[type];
+}
+
+size_t ggml_element_size(const struct ggml_tensor * tensor) {
+ return GGML_TYPE_SIZE[tensor->type];
+}
+
+bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_vector(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->ne[0] == t1->ne[0]) &&
+ (t0->ne[2] == t1->ne[2]) &&
+ (t0->ne[3] == t1->ne[3]);
+}
+
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+ tensor->nb[1] == tensor->nb[0]*tensor->ne[0] &&
+ tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+ tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];;
+}
+
+bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->ne[0] == t1->ne[0] ) &&
+ (t0->ne[1] == t1->ne[1] ) &&
+ (t0->ne[2] == t1->ne[2] ) &&
+ (t0->ne[3] == t1->ne[3] );
+}
+
+// check if t1 can be represented as a repeatition of t0
+bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t1->ne[0]%t0->ne[0] == 0) &&
+ (t1->ne[1]%t0->ne[1] == 0) &&
+ (t1->ne[2]%t0->ne[2] == 0) &&
+ (t1->ne[3]%t0->ne[3] == 0);
+}
+
+int ggml_up32(int n) {
+ return (n + 31) & ~31;
+}
+
+int ggml_up64(int n) {
+ return (n + 63) & ~63;
+}
+
+// assert that pointer is aligned to GGML_MEM_ALIGN
+#define ggml_assert_aligned(ptr) \
+ assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_context * ggml_init(struct ggml_init_params params) {
+ // find non-used context in g_state
+ struct ggml_context * ctx = NULL;
+
+ static bool first_time = true;
+ if (first_time) {
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ g_state.contexts[i].used = false;
+ }
+ first_time = false;
+ }
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ if (!g_state.contexts[i].used) {
+ g_state.contexts[i].used = true;
+ ctx = &g_state.contexts[i].context;
+
+ GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
+ break;
+ }
+ }
+
+ if (ctx == NULL) {
+ GGML_PRINT_DEBUG("%s\n", "ggml_init: no unused context found");
+ return NULL;
+ }
+
+ *ctx = (struct ggml_context) {
+ .mem_size = params.mem_size,
+ .mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+ .mem_buffer_owned = params.mem_buffer ? false : true,
+ .n_objects = 0,
+ .objects_begin = NULL,
+ .objects_end = NULL,
+ };
+
+ ggml_assert_aligned(ctx->mem_buffer);
+
+ return ctx;
+}
+
+void ggml_free(struct ggml_context * ctx) {
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ if (&g_state.contexts[i].context == ctx) {
+ g_state.contexts[i].used = false;
+
+ GGML_PRINT_DEBUG("ggml_free: context %d with %d objects has been freed. memory used = %zu\n",
+ i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
+
+ if (ctx->mem_buffer_owned) {
+ free(ctx->mem_buffer);
+ }
+
+ return;
+ }
+ }
+
+ GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+}
+
+size_t ggml_used_mem(const struct ggml_context * ctx) {
+ return ctx->objects_end->offset + ctx->objects_end->size;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_tensor * ggml_new_tensor_impl(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int* ne,
+ void* data) {
+ // always insert objects at the end of the context's memory pool
+ struct ggml_object * obj_cur = ctx->objects_end;
+
+ const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
+ const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
+ const size_t cur_end = cur_offset + cur_size;
+
+ size_t size_needed = 0;
+
+ if (data == NULL) {
+ size_needed += GGML_TYPE_SIZE[type];
+ for (int i = 0; i < n_dims; i++) {
+ size_needed *= ne[i];
+ }
+ // align to GGML_MEM_ALIGN
+ size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
+
+ }
+ size_needed += sizeof(struct ggml_tensor);
+
+ if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+ GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
+ assert(false);
+ return NULL;
+ }
+
+ char * const mem_buffer = ctx->mem_buffer;
+
+ struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
+
+ *obj_new = (struct ggml_object) {
+ .offset = cur_end + GGML_OBJECT_SIZE,
+ .size = size_needed,
+ .next = NULL,
+ };
+
+ if (obj_cur != NULL) {
+ obj_cur->next = obj_new;
+ } else {
+ // this is the first object in this context
+ ctx->objects_begin = obj_new;
+ }
+
+ ctx->objects_end = obj_new;
+
+ //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
+
+ struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
+
+ ggml_assert_aligned(result);
+
+ *result = (struct ggml_tensor) {
+ /*.type =*/ type,
+ /*.n_dims =*/ n_dims,
+ /*.ne =*/ { 1, 1, 1, 1 },
+ /*.nb =*/ { 0, 0, 0, 0 },
+ /*.op =*/ GGML_OP_NONE,
+ /*.is_param =*/ false,
+ /*.grad =*/ NULL,
+ /*.src0 =*/ NULL,
+ /*.src1 =*/ NULL,
+ /*.n_tasks =*/ 0,
+ /*.perf_runs =*/ 0,
+ /*.perf_cycles =*/ 0,
+ /*.perf_time_us =*/ 0,
+ /*.data =*/ data == NULL ? (void *)(result + 1) : data,
+ /*.pad =*/ { 0 },
+ };
+
+ ggml_assert_aligned(result->data);
+
+ for (int i = 0; i < n_dims; i++) {
+ result->ne[i] = ne[i];
+ }
+
+ result->nb[0] = GGML_TYPE_SIZE[type];
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
+ }
+
+ ctx->n_objects++;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_new_tensor(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int* ne) {
+ return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
+}
+
+struct ggml_tensor * ggml_new_tensor_1d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0) {
+ return ggml_new_tensor(ctx, type, 1, &ne0);
+}
+
+struct ggml_tensor * ggml_new_tensor_2d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1) {
+ const int ne[2] = { ne0, ne1 };
+ return ggml_new_tensor(ctx, type, 2, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_3d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2) {
+ const int ne[3] = { ne0, ne1, ne2 };
+ return ggml_new_tensor(ctx, type, 3, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_4d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3) {
+ const int ne[4] = { ne0, ne1, ne2, ne3 };
+ return ggml_new_tensor(ctx, type, 4, ne);
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+ ggml_set_f32(result, value);
+
+ return result;
+}
+
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+ return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL);
+}
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+ memset(tensor->data, 0, ggml_nbytes(tensor));
+ return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+ const int n = ggml_nrows(tensor);
+ const int nc = tensor->ne[0];
+ const size_t n1 = tensor->nb[1];
+
+ char * const data = tensor->data;
+
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+
+ return tensor;
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ return ((int8_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ return ((int16_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ return ((int32_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ return ((float *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+
+ assert(false);
+ return 0.0f;
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ ((int8_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ ((int16_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ ((int32_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ ((float *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+void * ggml_get_data(const struct ggml_tensor * tensor) {
+ return tensor->data;
+}
+
+float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
+ assert(tensor->type == GGML_TYPE_F32);
+ return (float *)(tensor->data);
+}
+
+struct ggml_tensor * ggml_view_tensor(
+ struct ggml_context * ctx,
+ const struct ggml_tensor * src) {
+ return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_dup
+
+struct ggml_tensor * ggml_dup_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_DUP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_dup(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_dup_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_dup_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_dup_impl(ctx, a, true);
+}
+
+// ggml_add
+
+struct ggml_tensor * ggml_add_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add_impl(ctx, a, b, true);
+}
+
+// ggml_sub
+
+struct ggml_tensor * ggml_sub_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SUB;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sub(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_sub_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_sub_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_sub_impl(ctx, a, b, true);
+}
+
+// ggml_mul
+
+struct ggml_tensor * ggml_mul_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ assert(is_node == false);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MUL;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_mul(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_mul_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_mul_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_mul_impl(ctx, a, b, true);
+}
+
+// ggml_div
+
+struct ggml_tensor * ggml_div_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ assert(is_node == false);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_DIV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_div(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_div_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_div_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_div_impl(ctx, a, b, true);
+}
+
+// ggml_sqr
+
+struct ggml_tensor * ggml_sqr_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SQR;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sqr(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqr_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqr_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqr_impl(ctx, a, true);
+}
+
+// ggml_sqrt
+
+struct ggml_tensor * ggml_sqrt_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SQRT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sqrt(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqrt_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqrt_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqrt_impl(ctx, a, true);
+}
+
+// ggml_sum
+
+struct ggml_tensor * ggml_sum(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+ result->op = GGML_OP_SUM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_mean
+
+struct ggml_tensor * ggml_mean(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement
+ is_node = true;
+ }
+
+ int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
+
+ result->op = GGML_OP_MEAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_repeat
+
+struct ggml_tensor * ggml_repeat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_can_repeat(a, b));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ if (ggml_are_same_shape(a, b) && !is_node) {
+ return a;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+ result->op = GGML_OP_REPEAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_abs
+
+struct ggml_tensor * ggml_abs_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ABS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_abs(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_abs_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_abs_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_abs_impl(ctx, a, true);
+}
+
+
+// ggml_sgn
+
+struct ggml_tensor * ggml_sgn_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SGN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sgn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sgn_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sgn_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sgn_impl(ctx, a, true);
+}
+
+// ggml_neg
+
+struct ggml_tensor * ggml_neg_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_NEG;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_neg(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_neg_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_neg_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_neg_impl(ctx, a, true);
+}
+
+// ggml_step
+
+struct ggml_tensor * ggml_step_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_STEP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_step(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_step_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_step_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_step_impl(ctx, a, true);
+}
+
+// ggml_relu
+
+struct ggml_tensor * ggml_relu_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_RELU;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_relu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_relu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_relu_impl(ctx, a, true);
+}
+
+// ggml_gelu
+
+struct ggml_tensor * ggml_gelu_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_GELU;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_gelu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_gelu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_gelu_impl(ctx, a, true);
+}
+
+// ggml_norm
+
+struct ggml_tensor * ggml_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store epsilon here?
+
+ return result;
+}
+
+struct ggml_tensor * ggml_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_norm_impl(ctx, a, true);
+}
+
+// ggml_mul_mat
+
+struct ggml_tensor * ggml_mul_mat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_can_mul_mat(a, b));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+ result->op = GGML_OP_MUL_MAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_scale
+
+struct ggml_tensor * ggml_scale_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_is_scalar(b));
+ assert(ggml_is_padded_1d(a));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->op = GGML_OP_SCALE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_scale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_scale_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_scale_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_scale_impl(ctx, a, b, true);
+}
+
+// ggml_cpy
+
+struct ggml_tensor * ggml_cpy_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_nelements(a) == ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // make a view of the destination
+ struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+
+ result->op = GGML_OP_CPY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_cpy(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_cpy_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_cpy_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_cpy_impl(ctx, a, b, true);
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_contiguous(a));
+ assert(ggml_is_contiguous(b));
+ assert(ggml_nelements(a) == ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1) {
+ assert(ggml_is_contiguous(a));
+ assert(ggml_nelements(a) == ne0*ne1);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[2] = { ne0, ne1 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2) {
+ assert(ggml_is_contiguous(a));
+ assert(ggml_nelements(a) == ne0*ne1*ne2);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[3] = { ne0, ne1, ne2 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_view_1d
+
+struct ggml_tensor * ggml_view_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ size_t offset) {
+ if (a->grad) {
+ assert(false); // gradient propagation is not supported
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
+
+ result->op = GGML_OP_VIEW;
+ result->grad = NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the offset here?
+
+ return result;
+}
+
+// ggml_view_2d
+
+struct ggml_tensor * ggml_view_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ size_t nb1,
+ size_t offset) {
+ if (a->grad) {
+ assert(false); // gradient propagation is not supported
+ }
+
+ const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
+
+ result->nb[1] = nb1;
+ result->nb[2] = result->nb[1]*ne1;
+ result->nb[3] = result->nb[2];
+
+ result->op = GGML_OP_VIEW;
+ result->grad = NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the offset here?
+
+ return result;
+}
+
+// ggml_permute
+
+struct ggml_tensor * ggml_permute(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int axis0,
+ int axis1,
+ int axis2,
+ int axis3) {
+ assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+ assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+ assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+ assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+ assert(axis0 != axis1);
+ assert(axis0 != axis2);
+ assert(axis0 != axis3);
+ assert(axis1 != axis2);
+ assert(axis1 != axis3);
+ assert(axis2 != axis3);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ int ne[GGML_MAX_DIMS];
+ int nb[GGML_MAX_DIMS];
+
+ ne[axis0] = a->ne[0];
+ ne[axis1] = a->ne[1];
+ ne[axis2] = a->ne[2];
+ ne[axis3] = a->ne[3];
+
+ nb[axis0] = a->nb[0];
+ nb[axis1] = a->nb[1];
+ nb[axis2] = a->nb[2];
+ nb[axis3] = a->nb[3];
+
+ result->ne[0] = ne[0];
+ result->ne[1] = ne[1];
+ result->ne[2] = ne[2];
+ result->ne[3] = ne[3];
+
+ result->nb[0] = nb[0];
+ result->nb[1] = nb[1];
+ result->nb[2] = nb[2];
+ result->nb[3] = nb[3];
+
+ result->op = GGML_OP_PERMUTE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the permutation here?
+
+ return result;
+}
+
+// ggml_transpose
+
+struct ggml_tensor * ggml_transpose(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->ne[0] = a->ne[1];
+ result->ne[1] = a->ne[0];
+
+ result->nb[0] = a->nb[1];
+ result->nb[1] = a->nb[0];
+
+ result->op = GGML_OP_TRANSPOSE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_get_rows
+
+struct ggml_tensor * ggml_get_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: implement non F32 return
+ //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+ struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+
+ result->op = GGML_OP_GET_ROWS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_diag_mask_inf
+
+struct ggml_tensor * ggml_diag_mask_inf(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+ ((int32_t *) b->data)[0] = n_past;
+
+ result->op = GGML_OP_DIAG_MASK_INF;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_soft_max
+
+struct ggml_tensor * ggml_soft_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->op = GGML_OP_SOFT_MAX;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_rope
+
+struct ggml_tensor * ggml_rope(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode) {
+ assert(n_past >= 0);
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+ ((int32_t *) b->data)[0] = n_past;
+ ((int32_t *) b->data)[1] = n_dims;
+ ((int32_t *) b->data)[2] = mode;
+
+ result->op = GGML_OP_ROPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_conv_1d_1s
+
+struct ggml_tensor * ggml_conv_1d_1s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_matrix(b));
+ assert(a->ne[1] == b->ne[1]);
+ assert(a->ne[3] == 1);
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+ result->op = GGML_OP_CONV_1D_1S;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_conv_1d_2s
+
+struct ggml_tensor * ggml_conv_1d_2s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_matrix(b));
+ assert(a->ne[1] == b->ne[1]);
+ assert(a->ne[3] == 1);
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+ result->op = GGML_OP_CONV_1D_2S;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor) {
+ tensor->is_param = true;
+
+ assert(tensor->grad == NULL);
+ tensor->grad = ggml_dup_tensor(ctx, tensor);
+}
+
+// ggml_compute_forward_dup
+
+void ggml_compute_forward_dup_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_is_contiguous(dst));
+ assert(ggml_nelements(dst) == ggml_nelements(src0));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ //const int ne00 = src0->ne[0];
+ //const int ne01 = src0->ne[1];
+ //const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ //const size_t nb00 = src0->nb[0];
+ //const size_t nb01 = src0->nb[1];
+ //const size_t nb02 = src0->nb[2];
+ //const size_t nb03 = src0->nb[3];
+
+ if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+ memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+ return;
+ }
+
+ GGML_ASSERT(false); // TODO: implement
+}
+
+void ggml_compute_forward_dup_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(params->ith == 0);
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+ memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+ return;
+ }
+
+ if (src0->nb[0] == sizeof(float)) {
+ if (dst->type == GGML_TYPE_F32) {
+ int id = 0;
+ const size_t rs = ne00*nb00;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ char * dst_ptr = (char *) dst->data + id*rs;
+
+ memcpy(dst_ptr, src0_ptr, rs);
+
+ id++;
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ int id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ int id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ int id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+}
+
+void ggml_compute_forward_dup(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_dup_f16(params, src0, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_dup_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_add
+
+void ggml_compute_forward_add_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(params->ith == 0);
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+
+ const size_t nb10 = src1->nb[0];
+ const size_t nb11 = src1->nb[1];
+
+ const size_t nb0 = dst->nb[0];
+ const size_t nb1 = dst->nb[1];
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ if (nb10 == sizeof(float)) {
+ for (int j = 0; j < n; j++) {
+ ggml_vec_add_f32(nc,
+ (float *) ((char *) dst->data + j*nb1),
+ (float *) ((char *) src0->data + j*nb01),
+ (float *) ((char *) src1->data + j*nb11));
+ }
+ } else {
+ // src1 is not contiguous
+ for (int j = 0; j < n; j++) {
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+ for (int i = 0; i < nc; i++) {
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+ dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_add(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sub
+
+void ggml_compute_forward_sub_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sub_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+void ggml_compute_forward_sub(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sub_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mul
+
+void ggml_compute_forward_mul_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_mul_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+void ggml_compute_forward_mul(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mul_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_div
+
+void ggml_compute_forward_div_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_div_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+void ggml_compute_forward_div(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_div_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sqr
+
+void ggml_compute_forward_sqr_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sqr_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_sqr(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sqr_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sqrt
+
+void ggml_compute_forward_sqrt_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sqrt_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_sqrt(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sqrt_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sum
+
+void ggml_compute_forward_sum_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_is_scalar(dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ assert(ggml_is_scalar(dst));
+ assert(src0->nb[0] == sizeof(float));
+
+ *(float *) (dst->data) = 0.0f;
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ ggml_vec_sum_f32(ne00,
+ (float *) (dst->data),
+ (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_sum(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sum_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mean
+
+void ggml_compute_forward_mean_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ assert(src0->nb[0] == sizeof(float));
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
+
+ assert(ne0 == 1);
+ assert(ne1 == ne01);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+
+ UNUSED(ne0);
+ UNUSED(ne1);
+ UNUSED(ne2);
+ UNUSED(ne3);
+
+ const size_t nb1 = dst->nb[1];
+ const size_t nb2 = dst->nb[2];
+ const size_t nb3 = dst->nb[3];
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) = 0.0f;
+
+ ggml_vec_sum_f32(ne00,
+ (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+ *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_mean(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mean_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_repeat
+
+void ggml_compute_forward_repeat_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_can_repeat(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // TODO: implement support for rank > 2 tensors
+ assert(src0->ne[2] == 1);
+ assert(src0->ne[3] == 1);
+ assert( dst->ne[2] == 1);
+ assert( dst->ne[3] == 1);
+
+ const int nc = dst->ne[0];
+ const int nr = dst->ne[1];
+ const int nc0 = src0->ne[0];
+ const int nr0 = src0->ne[1];
+ const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
+
+ // TODO: support for transposed / permuted tensors
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ // TODO: maybe this is not optimal?
+ for (int i = 0; i < nrr; i++) {
+ for (int j = 0; j < ncr; j++) {
+ for (int k = 0; k < nr0; k++) {
+ ggml_vec_cpy_f32(nc0,
+ (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
+ (float *) ((char *) src0->data + ( k)*(src0->nb[1])));
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_repeat(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_repeat_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_abs
+
+void ggml_compute_forward_abs_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_abs_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_abs(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_abs_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sgn
+
+void ggml_compute_forward_sgn_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sgn_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_sgn(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sgn_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_neg
+
+void ggml_compute_forward_neg_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_neg_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_neg(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_neg_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_step
+
+void ggml_compute_forward_step_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_step_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_step(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_step_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_relu
+
+void ggml_compute_forward_relu_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_relu_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+void ggml_compute_forward_relu(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_relu_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_gelu
+
+void ggml_compute_forward_gelu_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_gelu_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+void ggml_compute_forward_gelu(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_gelu_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_norm
+
+void ggml_compute_forward_norm_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ assert(src0->nb[0] == sizeof(float));
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ const size_t nb1 = dst->nb[1];
+ const size_t nb2 = dst->nb[2];
+ const size_t nb3 = dst->nb[3];
+
+ const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+ // TODO: optimize
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float mean = 0.0;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ mean += x[i00];
+ }
+
+ mean /= ne00;
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ ggml_float sum2 = 0.0;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ ggml_float v = x[i00] - mean;
+ y[i00] = v;
+ sum2 += v*v;
+ }
+
+ const float scale = 1.0/sqrt(sum2/ne00 + eps);
+
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_norm(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_norm_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mul_mat
+
+void ggml_compute_forward_mul_mat_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+ const int ne13 = src1->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
+ const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ const int nb12 = src1->nb[2];
+ const int nb13 = src1->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ assert(ne02 == ne12);
+ assert(ne03 == ne13);
+ assert(ne2 == ne12);
+ assert(ne3 == ne13);
+
+ // TODO: we don't support permuted src0
+ assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+
+ // dst cannot be transposed or permuted
+ assert(nb0 == sizeof(float));
+ assert(nb0 <= nb1);
+ assert(nb1 <= nb2);
+ assert(nb2 <= nb3);
+
+ assert(ne0 == ne01);
+ assert(ne1 == ne11);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+ //
+ // nb00 < nb01 - src0 is transposed
+ // compute by src0 columns
+
+ if (params->type == GGML_TASK_INIT) {
+ if (nb01 >= nb00) {
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ if (nb01 >= nb00) {
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+ float * const wdata = params->wdata;
+
+ // cols per thread
+ const int dc = (ne + nth - 1)/nth;
+
+ // col range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, ne);
+
+ ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+
+ for (int k = 1; k < nth; k++) {
+ ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+ }
+
+ return;
+ }
+
+//#ifdef GGML_USE_ACCELERATE
+// // try to use BLAS
+//
+// if (nb01 >= nb00 && ne0 > 1024 && ne1 > 1024) {
+// if (params->ith != 0) return;
+// printf("XXXXXXXX\n");
+//
+// GGML_ASSERT(ggml_is_contiguous(src0));
+// GGML_ASSERT(ggml_is_contiguous(src1));
+//
+// printf("ne00 = %d, ne01 = %d, ne02 = %d, ne03 = %d\n", ne00, ne01, ne02, ne03);
+// printf("ne10 = %d, ne11 = %d, ne12 = %d, ne13 = %d\n", ne10, ne11, ne12, ne13);
+// printf("ne0 = %d, ne1 = %d, ne2 = %d, ne3 = %d\n", ne0, ne1, ne2, ne3);
+//
+// printf("nb00 = %d, nb01 = %d, nb02 = %d, nb03 = %d\n", nb00, nb01, nb02, nb03);
+// printf("nb10 = %d, nb11 = %d, nb12 = %d, nb13 = %d\n", nb10, nb11, nb12, nb13);
+// printf("nb0 = %d, nb1 = %d, nb2 = %d, nb3 = %d\n", nb0, nb1, nb2, nb3);
+//
+// float * const wdata = params->wdata;
+//
+// int64_t tsum = 0.0;
+// for (int i03 = 0; i03 < ne03; i03++) {
+// for (int i02 = 0; i02 < ne02; i02++) {
+// const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+// const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+// float * z = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+//
+// // transpose src1
+// for (int j = 0; j < ne11; ++j) {
+// for (int i = 0; i < ne10; ++i) {
+// wdata[i*ne11 + j] = y[j*ne10 + i];
+// }
+// }
+//
+// {
+// const int64_t tt0 = ggml_time_us();
+// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
+// 1500, 1500, 64,
+// 1.0, x, 64,
+// wdata, 1500,
+// 0.0, z, 1500);
+// const int64_t tt1 = ggml_time_us();
+// tsum += tt1 - tt0;
+// }
+//
+// // transpose z
+// for (int j = 0; j < ne1; ++j) {
+// for (int i = 0; i < ne0; ++i) {
+// wdata[i*ne1 + j] = z[j*ne0 + i];
+// }
+// }
+//
+// memcpy(z, wdata, ne0*ne1*sizeof(float));
+//
+// //cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
+// // ne0, ne1, 64,
+// // 1.0f,
+// // x, ne00,
+// // y, ne11,
+// // 0.0f,
+// // z, 1500);
+// }
+// }
+// printf("time = %f ms\n", tsum/1000.0);
+// return;
+// } else {
+// //cblas_sgemv(CblasRowMajor, CblasTrans, ne00, ne01, 1.0, src0->data, ne01, src1->data, 1, 0.0, dst->data, 1);
+// }
+//
+//#endif
+
+
+ if (nb01 >= nb00) {
+ // TODO: do not support transposed src1
+ assert(nb10 == sizeof(float));
+
+ // parallelize by src0 rows using ggml_vec_dot_f32
+
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ for (int ic = 0; ic < ne11; ++ic) {
+ // src1 indices
+ const int i13 = i03;
+ const int i12 = i02;
+ const int i11 = ic;
+
+ // dst indices
+ const int i0 = i01;
+ const int i1 = i11;
+ const int i2 = i02;
+ const int i3 = i03;
+
+ ggml_vec_dot_f32(ne00,
+ (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+ (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+ }
+ }
+ } else {
+ // parallelize by src1 columns using ggml_vec_mad_f32
+ // each thread has its own work data
+ // during FINALIZE we accumulate all work data into dst
+
+ // total columns in src1
+ const int nc = ne10;
+
+ // columns per thread
+ const int dc = (nc + nth - 1)/nth;
+
+ // column range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, nc);
+
+ // work data for thread
+ const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+ float * const wdata = params->wdata;
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ for (int ic = ic0; ic < ic1; ++ic) {
+ // src1 indices
+ const int i10 = ic;
+
+ // src0 indices
+ const int i03 = i13;
+ const int i02 = i12;
+ const int i00 = ic;
+
+ // dst indices
+ const int i1 = i11;
+ const int i2 = i12;
+ const int i3 = i13;
+
+ assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+ ggml_vec_mad_f32(ne01,
+ (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
+ (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
+ *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
+ }
+ }
+ }
+ }
+ }
+
+ //int64_t t1 = ggml_perf_time_us();
+ //static int64_t acc = 0;
+ //acc += t1 - t0;
+ //if (t1 - t0 > 10) {
+ // printf("\n");
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+ //}
+}
+
+void ggml_compute_forward_mul_mat_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+ const int ne13 = src1->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
+ const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ const int nb12 = src1->nb[2];
+ const int nb13 = src1->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ assert(ne02 == ne12);
+ assert(ne03 == ne13);
+ assert(ne2 == ne12);
+ assert(ne3 == ne13);
+
+ // TODO: we don't support permuted src0
+ assert(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+
+ // dst cannot be transposed or permuted
+ assert(nb0 == sizeof(float));
+ assert(nb0 <= nb1);
+ assert(nb1 <= nb2);
+ assert(nb2 <= nb3);
+
+ assert(ne0 == ne01);
+ assert(ne1 == ne11);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+ //
+ // nb00 < nb01 - src0 is transposed
+ // compute by src0 columns
+
+ if (params->type == GGML_TASK_INIT) {
+ if (nb01 >= nb00) {
+ ggml_fp16_t * const wdata = params->wdata;
+
+ int id = 0;
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ for (int i10 = 0; i10 < ne10; ++i10) {
+ wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+ }
+ }
+ }
+ }
+
+ GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ if (nb01 >= nb00) {
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+ ggml_fp16_t * const wdata = params->wdata;
+
+ // cols per thread
+ const int dc = (ne + nth - 1)/nth;
+
+ // col range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, ne);
+
+ for (int i = ic0; i < ic1; ++i) {
+ ((float *) dst->data)[i] = ggml_fp16_to_fp32(wdata[i]);
+ }
+
+ for (int k = 1; k < nth; k++) {
+ for (int i = ic0; i < ic1; ++i) {
+ ((float *) dst->data)[i] += ggml_fp16_to_fp32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
+ }
+ }
+
+ return;
+ }
+
+ if (nb01 >= nb00) {
+ // fp16 -> half the size, so divide by 2
+ // TODO: do not support transposed src1
+ assert(nb10/2 == sizeof(ggml_fp16_t));
+
+ // parallelize by src0 rows using ggml_vec_dot_f32
+
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ ggml_fp16_t * wdata = params->wdata;
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int i13 = i03;
+ const int i12 = i02;
+
+ const int i0 = i01;
+ const int i2 = i02;
+ const int i3 = i03;
+
+ ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00;
+
+ float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+ for (int ic = 0; ic < ne11; ++ic) {
+ assert(ne00 % 32 == 0);
+
+ ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+ }
+ }
+ } else {
+ // parallelize by src1 columns using ggml_vec_mad_f32
+ // each thread has its own work data
+ // during FINALIZE we accumulate all work data into dst
+
+ // total columns in src1
+ const int nc = ne10;
+
+ // columns per thread
+ const int dc = (nc + nth - 1)/nth;
+
+ // column range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, nc);
+
+ // work data for thread
+ const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+ ggml_fp16_t * const wdata = params->wdata;
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ // dst indices
+ const int i1 = i11;
+ const int i2 = i12;
+ const int i3 = i13;
+
+ ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+
+ for (int ic = ic0; ic < ic1; ++ic) {
+ // src1 indices
+ const int i10 = ic;
+
+ // src0 indices
+ const int i03 = i13;
+ const int i02 = i12;
+ const int i00 = ic;
+
+ assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+ ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
+ float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+
+ ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
+ }
+ }
+ }
+ }
+ }
+
+ //int64_t t1 = ggml_time_us();
+ //static int64_t acc = 0;
+ //acc += t1 - t0;
+ //if (t1 - t0 > 10) {
+ // printf("\n");
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+ //}
+}
+
+void ggml_compute_forward_mul_mat(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_scale
+
+void ggml_compute_forward_scale_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // scale factor
+ const float v = *(float *) src1->data;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v);
+ }
+}
+
+void ggml_compute_forward_scale(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_scale_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_cpy
+
+void ggml_compute_forward_cpy(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, src0, dst);
+}
+
+// ggml_compute_forward_reshape
+
+void ggml_compute_forward_reshape(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+ UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+void ggml_compute_forward_view(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+}
+
+// ggml_compute_forward_permute
+
+void ggml_compute_forward_permute(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+}
+
+// ggml_compute_forward_transpose
+
+void ggml_compute_forward_transpose(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+}
+
+// ggml_compute_forward_get_rows
+
+void ggml_compute_forward_get_rows_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nelements(src1);
+
+ assert( dst->ne[0] == nc);
+ assert( dst->ne[1] == nr);
+ assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+ for (int i = 0; i < nr; ++i) {
+ const int r = ((int32_t *) src1->data)[i];
+
+ for (int j = 0; j < nc; ++j) {
+ ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
+ ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = ggml_fp16_to_fp32(v);
+ }
+ }
+}
+
+void ggml_compute_forward_get_rows_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nelements(src1);
+
+ assert( dst->ne[0] == nc);
+ assert( dst->ne[1] == nr);
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < nr; ++i) {
+ const int r = ((int32_t *) src1->data)[i];
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i*dst->nb[1]),
+ (float *) ((char *) src0->data + r*src0->nb[1]));
+ }
+}
+
+void ggml_compute_forward_get_rows(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+void ggml_compute_forward_diag_mask_inf_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(src1->type == GGML_TYPE_I32);
+ assert(ggml_nelements(src1) == 1);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n_past = ((int32_t *) src1->data)[0];
+
+ // TODO: handle transposed/permuted matrices
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+ const int nr = src0->ne[1];
+ const int nz = n/nr;
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int k = 0; k < nz; k++) {
+ for (int j = 0; j < nr; j++) {
+ for (int i = n_past; i < nc; i++) {
+ if (i > n_past + j) {
+ *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY;
+ }
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_diag_mask_inf(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_soft_max
+
+void ggml_compute_forward_soft_max_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // TODO: handle transposed/permuted matrices
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float *p = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(p[i]));
+ }
+#endif
+
+ float max = -INFINITY;
+ for (int i = 0; i < nc; i++) {
+ max = MAX(max, p[i]);
+ }
+
+ ggml_float sum = 0.0;
+ for (int i = 0; i < nc; i++) {
+ const ggml_float v = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+ sum += v;
+ p[i] = v;
+ }
+
+ assert(sum > 0.0f);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(nc, p, sum);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(p[i]));
+ assert(!isinf(p[i]));
+ }
+#endif
+ }
+}
+
+void ggml_compute_forward_soft_max(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_soft_max_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_rope
+
+void ggml_compute_forward_rope_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(src1->type == GGML_TYPE_I32);
+ assert(ggml_nelements(src1) == 3);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n_past = ((int32_t *) src1->data)[0];
+ const int n_dims = ((int32_t *) src1->data)[1];
+ const int mode = ((int32_t *) src1->data)[2];
+
+ //const int ne0 = src0->ne[0];
+ const int ne1 = src0->ne[1];
+ const int ne2 = src0->ne[2];
+ const int ne3 = src0->ne[3];
+
+ const int nb0 = src0->nb[0];
+ const int nb1 = src0->nb[1];
+ const int nb2 = src0->nb[2];
+ const int nb3 = src0->nb[3];
+
+ //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+ //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+ assert(nb0 == sizeof(float));
+
+ // TODO: optimize
+ for (int i3 = 0; i3 < ne3; i3++) {
+ for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+ const int p = (mode == 0 ? n_past + i2 : i2);
+ for (int i1 = 0; i1 < ne1; i1++) {
+ for (int i0 = 0; i0 < n_dims; i0 += 2) {
+ const double theta = pow(10000.0, ((double)-i0)/n_dims);
+
+ const double cos_theta = cos(p*theta);
+ const double sin_theta = sin(p*theta);
+
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ double x0 = src[0];
+ double x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ }
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_rope(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rope_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_conv_1d_1s
+
+void ggml_compute_forward_conv_1d_1s_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // WHISPER
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+ ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ ggml_fp16_t * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = ggml_fp32_to_fp16(src[i10]);
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; ++i0) {
+ dst_data[i0] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f16(ew0, &v,
+ (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0] += v;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_conv_1d_1s_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // WHISPER
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ float * const wdata = (float *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+ float * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ float * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; ++i0) {
+ dst_data[i0] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f32(ew0, &v,
+ (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0] += v;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_conv_1d_1s(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_conv_1d_2s
+
+void ggml_compute_forward_conv_1d_2s_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // WHISPER
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+ ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ ggml_fp16_t * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = ggml_fp32_to_fp16(src[i10]);
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; i0 += 2) {
+ dst_data[i0/2] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f16(ew0, &v,
+ (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0/2] += v;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_conv_1d_2s_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // WHISPER
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ float * const wdata = (float *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+ float * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ float * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; i0 += 2) {
+ dst_data[i0/2] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f32(ew0, &v,
+ (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0/2] += v;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_conv_1d_2s(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+/////////////////////////////////
+
+void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+ assert(params);
+
+ switch (tensor->op) {
+ case GGML_OP_DUP:
+ {
+ ggml_compute_forward_dup(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_ADD:
+ {
+ ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SUB:
+ {
+ ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_MUL:
+ {
+ ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_DIV:
+ {
+ ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SQR:
+ {
+ ggml_compute_forward_sqr(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_SQRT:
+ {
+ ggml_compute_forward_sqrt(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_SUM:
+ {
+ ggml_compute_forward_sum(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_MEAN:
+ {
+ ggml_compute_forward_mean(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ ggml_compute_forward_repeat(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_ABS:
+ {
+ ggml_compute_forward_abs(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_SGN:
+ {
+ ggml_compute_forward_sgn(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_NEG:
+ {
+ ggml_compute_forward_neg(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_STEP:
+ {
+ ggml_compute_forward_step(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_RELU:
+ {
+ ggml_compute_forward_relu(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_GELU:
+ {
+ ggml_compute_forward_gelu(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_NORM:
+ {
+ ggml_compute_forward_norm(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_compute_forward_cpy(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_RESHAPE:
+ {
+ ggml_compute_forward_reshape(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_VIEW:
+ {
+ ggml_compute_forward_view(params, tensor->src0);
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ ggml_compute_forward_permute(params, tensor->src0);
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ ggml_compute_forward_transpose(params, tensor->src0);
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ ggml_compute_forward_soft_max(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_ROPE:
+ {
+ ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_CONV_1D_1S:
+ {
+ ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_CONV_1D_2S:
+ {
+ ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_NONE:
+ {
+ // nop
+ } break;
+ case GGML_OP_COUNT:
+ {
+ assert(false);
+ } break;
+ };
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+ struct ggml_tensor * src0 = tensor->src0;
+ struct ggml_tensor * src1 = tensor->src1;
+
+ switch (tensor->op) {
+ case GGML_OP_DUP:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_ADD:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_SUB:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_MUL:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx, src1, tensor->grad),
+ inplace);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_impl(ctx,
+ src1->grad,
+ ggml_mul(ctx, src0, tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_DIV:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_div(ctx, tensor->grad, src1),
+ inplace);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_sub_impl(ctx,
+ src1->grad,
+ ggml_mul(ctx,
+ tensor->grad,
+ ggml_div(ctx, tensor, src1)),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SQR:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_mul(ctx, src0, tensor->grad),
+ ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SQRT:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_div(ctx,
+ ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
+ tensor),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SUM:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_repeat(ctx, tensor->grad, src0->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_MEAN:
+ {
+ assert(false); // TODO: implement
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_sum(ctx, tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_ABS:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_sgn(ctx, src0),
+ tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SGN:
+ {
+ if (src0->grad) {
+ // noop
+ }
+ } break;
+ case GGML_OP_NEG:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_STEP:
+ {
+ if (src0->grad) {
+ // noop
+ }
+ } break;
+ case GGML_OP_RELU:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_sub_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_step(ctx, src0),
+ tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_GELU:
+ {
+ assert(false); // TODO: not implemented
+ } break;
+ case GGML_OP_NORM:
+ {
+ assert(false); // TODO: not implemented
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ if (src0->grad) {
+ // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
+ assert(false);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_impl(ctx,
+ src1->grad,
+ // TODO: fix transpose, the node will break the graph connections
+ ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SCALE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CPY:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_RESHAPE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_VIEW:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_1D_1S:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_1D_2S:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_NONE:
+ {
+ // nop
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ };
+}
+
+void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+ if (node->grad == NULL) {
+ // this usually happens when we generate intermediate nodes from constants in the backward pass
+ // it can also happen during forward pass, if the user performs computations with constants
+ if (node->op != GGML_OP_NONE) {
+ //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
+ }
+ }
+
+ // check if already visited
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (cgraph->nodes[i] == node) {
+ return;
+ }
+ }
+
+ for (int i = 0; i < cgraph->n_leafs; i++) {
+ if (cgraph->leafs[i] == node) {
+ return;
+ }
+ }
+
+ if (node->src0) {
+ ggml_visit_parents(cgraph, node->src0);
+ }
+
+ if (node->src1) {
+ ggml_visit_parents(cgraph, node->src1);
+ }
+
+ if (node->op == GGML_OP_NONE && node->grad == NULL) {
+ // reached a leaf node, not part of the gradient graph (e.g. a constant)
+ assert(cgraph->n_leafs < GGML_MAX_NODES);
+
+ cgraph->leafs[cgraph->n_leafs] = node;
+ cgraph->n_leafs++;
+ } else {
+ assert(cgraph->n_nodes < GGML_MAX_NODES);
+
+ cgraph->nodes[cgraph->n_nodes] = node;
+ cgraph->grads[cgraph->n_nodes] = node->grad;
+ cgraph->n_nodes++;
+ }
+}
+
+void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+ if (!expand) {
+ cgraph->n_nodes = 0;
+ cgraph->n_leafs = 0;
+ }
+
+ const int n0 = cgraph->n_nodes;
+ UNUSED(n0);
+
+ ggml_visit_parents(cgraph, tensor);
+
+ const int n_new = cgraph->n_nodes - n0;
+ GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
+
+ if (n_new > 0) {
+ // the last added node should always be starting point
+ assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+ }
+}
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+ ggml_build_forward_impl(cgraph, tensor, true);
+}
+
+struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
+ struct ggml_cgraph result = {
+ /*.n_nodes =*/ 0,
+ /*.n_leafs =*/ 0,
+ /*.n_threads =*/ 0,
+ /*.work_size =*/ 0,
+ /*.work =*/ NULL,
+ /*.nodes =*/ { NULL },
+ /*.grads =*/ { NULL },
+ /*.leafs =*/ { NULL },
+ /*.perf_runs =*/ 0,
+ /*.perf_cycles =*/ 0,
+ /*.perf_time_us =*/ 0,
+ };
+
+ ggml_build_forward_impl(&result, tensor, false);
+
+ return result;
+}
+
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
+ struct ggml_cgraph result = *gf;
+
+ assert(gf->n_nodes > 0);
+
+ // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
+ if (keep) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->grad) {
+ node->grad = ggml_dup_tensor(ctx, node);
+ gf->grads[i] = node->grad;
+ }
+ }
+ }
+
+ for (int i = gf->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ // because we detached the grad nodes from the original graph, we can afford inplace operations
+ if (node->grad) {
+ ggml_compute_backward(ctx, node, keep);
+ }
+ }
+
+ for (int i = gf->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->is_param) {
+ GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+ ggml_build_forward_impl(&result, node->grad, true);
+ }
+ }
+
+ return result;
+}
+
+//
+// thread data
+//
+// synchronization is done via busy loops
+// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
+//
+
+#ifdef __APPLE__
+
+//#include <os/lock.h>
+
+//typedef os_unfair_lock ggml_lock_t;
+//
+//#define ggml_lock_init(x) UNUSED(x)
+//#define ggml_lock_destroy(x) UNUSED(x)
+//#define ggml_lock_lock os_unfair_lock_lock
+//#define ggml_lock_unlock os_unfair_lock_unlock
+//
+//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x) UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x) UNUSED(x)
+#define ggml_lock_unlock(x) UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+#else
+
+//typedef pthread_spinlock_t ggml_lock_t;
+
+//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
+//#define ggml_lock_destroy pthread_spin_destroy
+//#define ggml_lock_lock pthread_spin_lock
+//#define ggml_lock_unlock pthread_spin_unlock
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x) UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x) UNUSED(x)
+#define ggml_lock_unlock(x) UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+#endif
+
+struct ggml_compute_state_shared {
+ ggml_lock_t spin;
+
+ int n_threads;
+
+ // synchronization primitives
+ atomic_int n_ready;
+ atomic_bool has_work;
+ atomic_bool stop; // stop all threads
+};
+
+struct ggml_compute_state {
+ pthread_t thrd;
+
+ struct ggml_compute_params params;
+ struct ggml_tensor * node;
+
+ struct ggml_compute_state_shared * shared;
+};
+
+// function used by each compute thread
+void * ggml_graph_compute_one(void * data) {
+ struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+ ggml_compute_forward(&state->params, state->node);
+
+ return NULL;
+}
+
+void * ggml_graph_compute_thread(void * data) {
+ struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+ const int n_threads = state->shared->n_threads;
+
+ while (true) {
+ if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
+ atomic_store(&state->shared->has_work, false);
+ } else {
+ while (atomic_load(&state->shared->has_work)) {
+ if (atomic_load(&state->shared->stop)) {
+ return NULL;
+ }
+ ggml_lock_lock (&state->shared->spin);
+ ggml_lock_unlock(&state->shared->spin);
+ }
+ }
+
+ atomic_fetch_sub(&state->shared->n_ready, 1);
+
+ // wait for work
+ while (!atomic_load(&state->shared->has_work)) {
+ if (atomic_load(&state->shared->stop)) {
+ return NULL;
+ }
+ ggml_lock_lock (&state->shared->spin);
+ ggml_lock_unlock(&state->shared->spin);
+ }
+
+ // check if we should stop
+ if (atomic_load(&state->shared->stop)) {
+ break;
+ }
+
+ if (state->node) {
+ ggml_compute_forward(&state->params, state->node);
+ state->node = NULL;
+ } else {
+ break;
+ }
+ }
+
+ return NULL;
+}
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+ if (cgraph->n_threads <= 0) {
+ cgraph->n_threads = 8;
+ }
+
+ const int n_threads = cgraph->n_threads;
+
+ struct ggml_compute_state_shared state_shared = {
+ /*.spin =*/ GGML_LOCK_INITIALIZER,
+ /*.n_threads =*/ n_threads,
+ /*.n_ready =*/ 0,
+ /*.has_work =*/ false,
+ /*.stop =*/ false,
+ };
+ struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
+
+ // create thread pool
+ if (n_threads > 1) {
+ ggml_lock_init(&state_shared.spin);
+
+ atomic_store(&state_shared.has_work, true);
+
+ for (int j = 0; j < n_threads - 1; j++) {
+ workers[j] = (struct ggml_compute_state) {
+ .thrd = 0,
+ .params = {
+ .type = GGML_TASK_COMPUTE,
+ .ith = j + 1,
+ .nth = n_threads,
+ .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ .wdata = cgraph->work ? cgraph->work->data : NULL,
+ },
+ .node = NULL,
+ .shared = &state_shared,
+ };
+ int rc = pthread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+ assert(rc == 0);
+ UNUSED(rc);
+ }
+ }
+
+ // initialize tasks + work buffer
+ {
+ size_t work_size = 0;
+
+ // thread scheduling for the different operations
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ switch (node->op) {
+ case GGML_OP_DUP:
+ case GGML_OP_ADD:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_SUM:
+ case GGML_OP_MEAN:
+ case GGML_OP_REPEAT:
+ case GGML_OP_ABS:
+ case GGML_OP_SGN:
+ case GGML_OP_NEG:
+ case GGML_OP_STEP:
+ case GGML_OP_RELU:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_GELU:
+ {
+ node->n_tasks = MIN(n_threads, ggml_nrows(node->src0));
+ } break;
+ case GGML_OP_NORM:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ // TODO: use different scheduling for different matrix sizes
+ node->n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ // TODO: better way to determine if the matrix is transposed
+ if (node->src0->nb[1] < node->src0->nb[0]) {
+ cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
+ } else {
+ if (node->src0->type == GGML_TYPE_F16 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = 0;
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ node->n_tasks = MIN(n_threads, ggml_nrows(node->src0));
+ } break;
+ case GGML_OP_CPY:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ node->n_tasks = MIN(n_threads, ggml_nrows(node->src0));
+ } break;
+ case GGML_OP_ROPE:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_CONV_1D_1S:
+ case GGML_OP_CONV_1D_2S:
+ {
+ // WHISPER
+ node->n_tasks = n_threads;
+
+ GGML_ASSERT(node->src0->ne[3] == 1);
+ GGML_ASSERT(node->src1->ne[2] == 1);
+ GGML_ASSERT(node->src1->ne[3] == 1);
+
+ size_t cur = 0;
+ const int nk = node->src0->ne[0];
+
+ if (node->src0->type == GGML_TYPE_F16 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(ggml_fp16_t)*(
+ nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+ ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ );
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*(
+ nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+ ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_NONE:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_COUNT:
+ {
+ assert(false);
+ } break;
+ };
+ }
+
+ if (cgraph->work != NULL && work_size > cgraph->work_size) {
+ assert(false); // TODO: better handling
+ }
+
+ if (work_size > 0 && cgraph->work == NULL) {
+ cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+
+ GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
+ cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+ }
+ }
+
+ const int64_t perf_start_cycles = ggml_perf_cycles();
+ const int64_t perf_start_time_us = ggml_perf_time_us();
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
+
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
+ //if (node->grad == NULL && node->perf_runs > 0) {
+ // continue;
+ //}
+
+ const int64_t perf_node_start_cycles = ggml_perf_cycles();
+ const int64_t perf_node_start_time_us = ggml_perf_time_us();
+
+ // INIT
+ struct ggml_compute_params params = {
+ /*.type =*/ GGML_TASK_INIT,
+ /*.ith =*/ 0,
+ /*.nth =*/ n_threads,
+ /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+ };
+
+ ggml_compute_forward(¶ms, node);
+
+ // COMPUTE
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ // launch thread pool
+ for (int j = 0; j < n_threads - 1; j++) {
+ workers[j].params = (struct ggml_compute_params) {
+ .type = GGML_TASK_COMPUTE,
+ .ith = j + 1,
+ .nth = n_threads,
+ .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ .wdata = cgraph->work ? cgraph->work->data : NULL,
+ };
+ workers[j].node = node;
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) > 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_store(&state_shared.has_work, true);
+ }
+
+ params.type = GGML_TASK_COMPUTE;
+ ggml_compute_forward(¶ms, node);
+
+ // wait for thread pool
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) != 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+ }
+
+ // FINALIZE
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ // launch thread pool
+ for (int j = 0; j < n_threads - 1; j++) {
+ workers[j].params = (struct ggml_compute_params) {
+ .type = GGML_TASK_FINALIZE,
+ .ith = j + 1,
+ .nth = n_threads,
+ .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ .wdata = cgraph->work ? cgraph->work->data : NULL,
+ };
+ workers[j].node = node;
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) > 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_store(&state_shared.has_work, true);
+ }
+
+ params.type = GGML_TASK_FINALIZE;
+ ggml_compute_forward(¶ms, node);
+
+ // wait for thread pool
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) != 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+ }
+
+ // performance stats (node)
+ {
+ int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles;
+ int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+
+ node->perf_runs++;
+ node->perf_cycles += perf_cycles_cur;
+ node->perf_time_us += perf_time_us_cur;
+ }
+ }
+
+ // join thread pool
+ if (n_threads > 1) {
+ atomic_store(&state_shared.stop, true);
+ atomic_store(&state_shared.has_work, true);
+
+ for (int j = 0; j < n_threads - 1; j++) {
+ int rc = pthread_join(workers[j].thrd, NULL);
+ assert(rc == 0);
+ UNUSED(rc);
+ }
+
+ ggml_lock_destroy(&state_shared.spin);
+ }
+
+ // performance stats (graph)
+ {
+ int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
+ int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+
+ cgraph->perf_runs++;
+ cgraph->perf_cycles += perf_cycles_cur;
+ cgraph->perf_time_us += perf_time_us_cur;
+
+ GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
+ __func__, cgraph->perf_runs,
+ (double) perf_cycles_cur / (double) ggml_cycles_per_ms(),
+ (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
+ (double) perf_time_us_cur / 1000.0,
+ (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+ }
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * grad = cgraph->grads[i];
+
+ if (grad) {
+ ggml_set_zero(grad);
+ }
+ }
+}
+
+void ggml_graph_print(const struct ggml_cgraph * cgraph) {
+ int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
+
+ GGML_PRINT("=== GRAPH ===\n");
+
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+
+ GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ perf_total_per_op_us[node->op] += node->perf_time_us;
+
+ GGML_PRINT(" - %3d: [ %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+ i,
+ node->ne[0], node->ne[1],
+ GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+ (double) node->perf_cycles / (double) ggml_cycles_per_ms(),
+ (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
+ (double) node->perf_time_us / 1000.0,
+ (double) node->perf_time_us / 1000.0 / node->perf_runs);
+ }
+
+ GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
+ for (int i = 0; i < cgraph->n_leafs; i++) {
+ struct ggml_tensor * node = cgraph->leafs[i];
+
+ GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
+ i,
+ node->ne[0], node->ne[1],
+ GGML_OP_LABEL[node->op]);
+ }
+
+ for (int i = 0; i < GGML_OP_COUNT; i++) {
+ GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
+ }
+
+ GGML_PRINT("========================================\n");
+}
+
+// check if node is part of the graph
+bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ if (cgraph == NULL) {
+ return true;
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (cgraph->nodes[i] == node) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * parent = cgraph->nodes[i];
+
+ if (parent->grad == node) {
+ return parent;
+ }
+ }
+
+ return NULL;
+}
+
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+ char color[16];
+
+ FILE * fp = fopen(filename, "w");
+ assert(fp);
+
+ fprintf(fp, "digraph G {\n");
+ fprintf(fp, " newrank = true;\n");
+ fprintf(fp, " rankdir = LR;\n");
+
+ for (int i = 0; i < gb->n_nodes; i++) {
+ struct ggml_tensor * node = gb->nodes[i];
+
+ if (ggml_graph_get_parent(gb, node) != NULL) {
+ continue;
+ }
+
+ if (node->is_param) {
+ snprintf(color, sizeof(color), "yellow");
+ } else if (node->grad) {
+ if (ggml_graph_find(gf, node)) {
+ snprintf(color, sizeof(color), "green");
+ } else {
+ snprintf(color, sizeof(color), "lightblue");
+ }
+ } else {
+ snprintf(color, sizeof(color), "white");
+ }
+
+ fprintf(fp, " \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"%d [%d, %d] | <x>%s",
+ (void *) node, color,
+ i, node->ne[0], node->ne[1],
+ GGML_OP_SYMBOL[node->op]);
+
+ if (node->grad) {
+ fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
+ } else {
+ fprintf(fp, "\"; ]\n");
+ }
+ }
+
+ for (int i = 0; i < gb->n_leafs; i++) {
+ struct ggml_tensor * node = gb->leafs[i];
+
+ snprintf(color, sizeof(color), "pink");
+
+ if (ggml_nelements(node) == 1) {
+ fprintf(fp, " \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>%.1e\"; ]\n",
+ (void *) node, color, ggml_get_f32_1d(node, 0));
+ } else {
+ fprintf(fp, " \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>CONST %d [%d, %d]\"; ]\n",
+ (void *) node, color,
+ i, node->ne[0], node->ne[1]);
+ }
+ }
+
+ for (int i = 0; i < gb->n_nodes; i++) {
+ struct ggml_tensor * node = gb->nodes[i];
+
+ struct ggml_tensor * parent = ggml_graph_get_parent(gb, node);
+
+ if (node->src0) {
+ struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0);
+
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n",
+ parent0 ? (void *) parent0 : (void *) node->src0,
+ parent0 ? "g" : "x",
+ parent ? (void *) parent : (void *) node,
+ parent ? "g" : "x",
+ parent ? "empty" : "vee",
+ parent ? "dashed" : "solid");
+ }
+
+ if (node->src1) {
+ struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1);
+
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n",
+ parent1 ? (void *) parent1 : (void *) node->src1,
+ parent1 ? "g" : "x",
+ parent ? (void *) parent : (void *) node,
+ parent ? "g" : "x",
+ parent ? "empty" : "vee",
+ parent ? "dashed" : "solid");
+ }
+ }
+
+ for (int i = 0; i < gb->n_leafs; i++) {
+ struct ggml_tensor * node = gb->leafs[i];
+
+ if (node->src0) {
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n",
+ (void *) node->src0, "x",
+ (void *) node, "x");
+ }
+
+ if (node->src1) {
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n",
+ (void *) node->src1, "x",
+ (void *) node, "x");
+ }
+ }
+
+ fprintf(fp, "}\n");
+
+ fclose(fp);
+
+ GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to set tensor from array
+ for (int j = 0; j < ne; ++j) {
+ ggml_set_f32_1d(ps[p], j, x[i++]);
+ }
+ }
+}
+
+void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to get all elements at once
+ for (int j = 0; j < ne; ++j) {
+ x[i++] = ggml_get_f32_1d(ps[p], j);
+ }
+ }
+}
+
+void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to get all elements at once
+ for (int j = 0; j < ne; ++j) {
+ g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
+ }
+ }
+}
+
+//
+// ADAM
+//
+// ref: https://arxiv.org/pdf/1412.6980.pdf
+//
+
+enum ggml_opt_result ggml_opt_adam(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb) {
+ assert(ggml_is_scalar(f));
+
+ gf->n_threads = params.n_threads;
+ gb->n_threads = params.n_threads;
+
+ // these will store the parameters we want to optimize
+ struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+ int np = 0;
+ int nx = 0;
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (gf->nodes[i]->is_param) {
+ GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+ assert(np < GGML_MAX_PARAMS);
+
+ ps[np++] = gf->nodes[i];
+ nx += ggml_nelements(gf->nodes[i]);
+ }
+ }
+
+ // constants
+ const float alpha = params.adam.alpha;
+ const float beta1 = params.adam.beta1;
+ const float beta2 = params.adam.beta2;
+ const float eps = params.adam.eps;
+
+ float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
+ float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
+ float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
+ float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
+ float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
+ float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
+ float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
+
+ float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+ // initialize
+ ggml_vec_set_f32(nx, m, 0.0f);
+ ggml_vec_set_f32(nx, v, 0.0f);
+
+ // update view
+ ggml_opt_get_params(np, ps, x);
+
+ // compute the function value
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ float fx_prev = ggml_get_f32_1d(f, 0);
+ if (pf) {
+ pf[0] = fx_prev;
+ }
+
+ int n_no_improvement = 0;
+ float fx_best = fx_prev;
+
+ // run the optimizer
+ for (int t = 0; t < params.adam.n_iter; ++t) {
+ GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
+
+ GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+ GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0));
+ GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0));
+
+ for (int i = 0; i < np; ++i) {
+ GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i,
+ ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0));
+ }
+
+ const int64_t t_start_wall = ggml_time_us();
+ const int64_t t_start_cpu = ggml_cycles();
+ UNUSED(t_start_wall);
+ UNUSED(t_start_cpu);
+
+ {
+ // update the gradient
+ ggml_opt_get_grad(np, ps, g1);
+
+ // m_t = beta1*m_t-1 + (1 - beta1)*g_t
+ ggml_vec_scale_f32(nx, m, beta1);
+ ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1);
+
+ // g2 = g1^2
+ ggml_vec_sqr_f32 (nx, g2, g1);
+
+ // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
+ ggml_vec_scale_f32(nx, v, beta2);
+ ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2);
+
+ // m^hat = m_t / (1 - beta1^t)
+ // v^hat = v_t / (1 - beta2^t)
+ // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+ ggml_vec_cpy_f32 (nx, mh, m);
+ ggml_vec_cpy_f32 (nx, vh, v);
+
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
+
+ ggml_vec_sqrt_f32 (nx, vh, vh);
+ ggml_vec_acc1_f32 (nx, vh, eps);
+
+ ggml_vec_div_f32 (nx, mh, mh, vh);
+ ggml_vec_sub_f32 (nx, x, x, mh);
+
+ // update the parameters
+ ggml_opt_set_params(np, ps, x);
+ }
+
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ const float fx = ggml_get_f32_1d(f, 0);
+
+ // check convergence
+ if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+ GGML_PRINT_DEBUG("converged\n");
+
+ return GGML_OPT_OK;
+ }
+
+ // delta-based convergence test
+ if (pf != NULL) {
+ // need at least params.past iterations to start checking for convergence
+ if (params.past <= t) {
+ const float rate = (pf[t%params.past] - fx)/fx;
+
+ if (fabs(rate) < params.delta) {
+ return GGML_OPT_OK;
+ }
+ }
+
+ pf[t%params.past] = fx;
+ }
+
+ // check for improvement
+ if (params.max_no_improvement > 0) {
+ if (fx_best > fx) {
+ fx_best = fx;
+ n_no_improvement = 0;
+ } else {
+ ++n_no_improvement;
+
+ if (n_no_improvement >= params.max_no_improvement) {
+ return GGML_OPT_OK;
+ }
+ }
+ }
+
+ fx_prev = fx;
+
+ {
+ const int64_t t_end_cpu = ggml_cycles();
+ GGML_PRINT_DEBUG("time iter: %5.3f s\n", (t_end_cpu - t_start_cpu)/CLOCKS_PER_SEC);
+ UNUSED(t_end_cpu);
+
+ const int64_t t_end_wall = ggml_time_us();
+ GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6);
+ UNUSED(t_end_wall);
+ }
+ }
+
+ return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+//
+// L-BFGS
+//
+// the L-BFGS implementation below is based on the following implementation:
+//
+// https://github.com/chokkan/liblbfgs
+//
+
+struct ggml_lbfgs_iteration_data {
+ float alpha;
+ float ys;
+ float * s;
+ float * y;
+};
+
+static enum ggml_opt_result linesearch_backtracking(
+ struct ggml_context * ctx,
+ const struct ggml_opt_params * params,
+ int nx,
+ float * x,
+ float * fx,
+ float * g,
+ float * d,
+ float * step,
+ const float * xp,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ const int np,
+ struct ggml_tensor * ps[]) {
+ int count = 0;
+
+ float width = 0.0f;
+ float dg = 0.0f;
+ float finit = 0.0f;
+ float dginit = 0.0f;
+ float dgtest = 0.0f;
+
+ const float dec = 0.5f;
+ const float inc = 2.1f;
+
+ if (*step <= 0.) {
+ return GGML_LINESEARCH_INVALID_PARAMETERS;
+ }
+
+ // compute the initial gradient in the search direction
+ ggml_vec_dot_f32(nx, &dginit, g, d);
+
+ // make sure that d points to a descent direction
+ if (0 < dginit) {
+ return GGML_LINESEARCH_FAIL;
+ }
+
+ // initialize local variables
+ finit = *fx;
+ dgtest = params->lbfgs.ftol*dginit;
+
+ while (true) {
+ ggml_vec_cpy_f32(nx, x, xp);
+ ggml_vec_mad_f32(nx, x, d, *step);
+
+ // evaluate the function and gradient values
+ {
+ ggml_opt_set_params(np, ps, x);
+
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ ggml_opt_get_grad(np, ps, g);
+
+ *fx = ggml_get_f32_1d(f, 0);
+ }
+
+ ++count;
+
+ if (*fx > finit + (*step)*dgtest) {
+ width = dec;
+ } else {
+ // Armijo condition is satisfied
+ if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) {
+ return count;
+ }
+
+ ggml_vec_dot_f32(nx, &dg, g, d);
+
+ // check the Wolfe condition
+ if (dg < params->lbfgs.wolfe * dginit) {
+ width = inc;
+ } else {
+ if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) {
+ // regular Wolfe conditions
+ return count;
+ }
+
+ if(dg > -params->lbfgs.wolfe*dginit) {
+ width = dec;
+ } else {
+ // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
+ return count;
+ }
+ return count;
+ }
+ }
+
+ if (*step < params->lbfgs.min_step) {
+ return GGML_LINESEARCH_MINIMUM_STEP;
+ }
+ if (*step > params->lbfgs.max_step) {
+ return GGML_LINESEARCH_MAXIMUM_STEP;
+ }
+ if (params->lbfgs.max_linesearch <= count) {
+ return GGML_LINESEARCH_MAXIMUM_ITERATIONS;
+ }
+
+ (*step) *= width;
+ }
+
+ return GGML_LINESEARCH_FAIL;
+}
+
+enum ggml_opt_result ggml_opt_lbfgs(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb) {
+ if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
+ params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
+ if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+ return GGML_OPT_INVALID_WOLFE;
+ }
+ }
+
+ gf->n_threads = params.n_threads;
+ gb->n_threads = params.n_threads;
+
+ const int m = params.lbfgs.m;
+
+ // these will store the parameters we want to optimize
+ struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+ int np = 0;
+ int nx = 0;
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (gf->nodes[i]->is_param) {
+ GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+ assert(np < GGML_MAX_PARAMS);
+
+ ps[np++] = gf->nodes[i];
+ nx += ggml_nelements(gf->nodes[i]);
+ }
+ }
+
+ float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
+ float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
+ float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
+ float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
+ float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+
+ float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+ float fx = 0.0f; // cost function value
+ float xnorm = 0.0f; // ||x||
+ float gnorm = 0.0f; // ||g||
+ float step = 0.0f;
+
+ // initialize x from the graph nodes
+ ggml_opt_get_params(np, ps, x);
+
+ // the L-BFGS memory
+ struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
+
+ for (int i = 0; i < m; ++i) {
+ lm[i].alpha = 0.0f;
+ lm[i].ys = 0.0f;
+ lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+ lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+ }
+
+ // evaluate the function value and its gradient
+ {
+ ggml_opt_set_params(np, ps, x);
+
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ ggml_opt_get_grad(np, ps, g);
+
+ fx = ggml_get_f32_1d(f, 0);
+ }
+
+ if (pf) {
+ pf[0] = fx;
+ }
+
+ float fx_best = fx;
+
+ // search direction = -gradient
+ ggml_vec_neg_f32(nx, d, g);
+
+ // ||x||, ||g||
+ ggml_vec_norm_f32(nx, &xnorm, x);
+ ggml_vec_norm_f32(nx, &gnorm, g);
+
+ if (xnorm < 1.0f) {
+ xnorm = 1.0f;
+ }
+
+ // already optimized
+ if (gnorm/xnorm <= params.lbfgs.eps) {
+ return GGML_OPT_OK;
+ }
+
+ // initial step
+ ggml_vec_norm_inv_f32(nx, &step, d);
+
+ int j = 0;
+ int k = 1;
+ int ls = 0;
+ int end = 0;
+ int bound = 0;
+ int n_no_improvement = 0;
+
+ float ys = 0.0f;
+ float yy = 0.0f;
+ float beta = 0.0f;
+
+ while (true) {
+ // store the current position and gradient vectors
+ ggml_vec_cpy_f32(nx, xp, x);
+ ggml_vec_cpy_f32(nx, gp, g);
+
+ ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+
+ if (ls < 0) {
+ // linesearch failed - go back to the previous point and return
+ ggml_vec_cpy_f32(nx, x, xp);
+ ggml_vec_cpy_f32(nx, g, gp);
+
+ return ls;
+ }
+
+ ggml_vec_norm_f32(nx, &xnorm, x);
+ ggml_vec_norm_f32(nx, &gnorm, g);
+
+ GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+
+ if (xnorm < 1.0) {
+ xnorm = 1.0;
+ }
+ if (gnorm/xnorm <= params.lbfgs.eps) {
+ // converged
+ return GGML_OPT_OK;
+ }
+
+ // delta-based convergence test
+ if (pf != NULL) {
+ // need at least params.past iterations to start checking for convergence
+ if (params.past <= k) {
+ const float rate = (pf[k%params.past] - fx)/fx;
+
+ if (fabs(rate) < params.delta) {
+ return GGML_OPT_OK;
+ }
+ }
+
+ pf[k%params.past] = fx;
+ }
+
+ // check for improvement
+ if (params.max_no_improvement > 0) {
+ if (fx < fx_best) {
+ fx_best = fx;
+ n_no_improvement = 0;
+ } else {
+ n_no_improvement++;
+
+ if (n_no_improvement >= params.max_no_improvement) {
+ return GGML_OPT_OK;
+ }
+ }
+ }
+
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+ // reached the maximum number of iterations
+ return GGML_OPT_DID_NOT_CONVERGE;
+ }
+
+ // update vectors s and y:
+ // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
+ // y_{k+1} = g_{k+1} - g_{k}.
+ //
+ ggml_vec_sub_f32(nx, lm[end].s, x, xp);
+ ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+
+ // compute scalars ys and yy:
+ // ys = y^t \cdot s -> 1 / \rho.
+ // yy = y^t \cdot y.
+ //
+ ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
+ ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+
+ lm[end].ys = ys;
+
+ // find new search direction
+ // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
+
+ bound = (m <= k) ? m : k;
+ k++;
+ end = (end + 1)%m;
+
+ // initialize search direction with -g
+ ggml_vec_neg_f32(nx, d, g);
+
+ j = end;
+ for (int i = 0; i < bound; ++i) {
+ j = (j + m - 1) % m;
+ // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
+ ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
+ lm[j].alpha /= lm[j].ys;
+ // q_{i} = q_{i+1} - \alpha_{i} y_{i}
+ ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+ }
+
+ ggml_vec_scale_f32(nx, d, ys/yy);
+
+ for (int i = 0; i < bound; ++i) {
+ // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
+ ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
+ beta /= lm[j].ys;
+ // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
+ ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
+ j = (j + 1)%m;
+ }
+
+ step = 1.0;
+ }
+
+ return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
+ struct ggml_opt_params result;
+
+ switch (type) {
+ case GGML_OPT_ADAM:
+ {
+ result = (struct ggml_opt_params) {
+ .type = GGML_OPT_ADAM,
+ .n_threads = 1,
+ .past = 0,
+ .delta = 1e-5f,
+
+ .max_no_improvement = 100,
+
+ .print_forward_graph = true,
+ .print_backward_graph = true,
+
+ .adam = {
+ .n_iter = 10000,
+ .alpha = 0.001f,
+ .beta1 = 0.9f,
+ .beta2 = 0.999f,
+ .eps = 1e-8f,
+ .eps_f = 1e-5f,
+ .eps_g = 1e-3f,
+ },
+ };
+ } break;
+ case GGML_OPT_LBFGS:
+ {
+ result = (struct ggml_opt_params) {
+ .type = GGML_OPT_LBFGS,
+ .n_threads = 1,
+ .past = 0,
+ .delta = 1e-5f,
+
+ .max_no_improvement = 0,
+
+ .print_forward_graph = true,
+ .print_backward_graph = true,
+
+ .lbfgs = {
+ .m = 6,
+ .n_iter = 100,
+ .max_linesearch = 20,
+
+ .eps = 1e-5f,
+ .ftol = 1e-4f,
+ .wolfe = 0.9f,
+ .min_step = 1e-20f,
+ .max_step = 1e+20f,
+
+ .linesearch = GGML_LINESEARCH_DEFAULT,
+ },
+ };
+ } break;
+ }
+
+ return result;
+}
+
+enum ggml_opt_result ggml_opt(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f) {
+ bool free_ctx = false;
+ if (ctx == NULL) {
+ struct ggml_init_params params_ctx = {
+ .mem_size = 16*1024*1024,
+ .mem_buffer = NULL,
+ };
+
+ ctx = ggml_init(params_ctx);
+ if (ctx == NULL) {
+ return GGML_OPT_NO_CONTEXT;
+ }
+
+ free_ctx = true;
+ }
+
+ enum ggml_opt_result result = GGML_OPT_OK;
+
+ // build forward + backward compute graphs
+ struct ggml_cgraph gf = ggml_build_forward (f);
+ struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
+
+ switch (params.type) {
+ case GGML_OPT_ADAM:
+ {
+ result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+ } break;
+ case GGML_OPT_LBFGS:
+ {
+ result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+ } break;
+ }
+
+ if (params.print_forward_graph) {
+ ggml_graph_print (&gf);
+ ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
+ }
+
+ if (params.print_backward_graph) {
+ ggml_graph_print (&gb);
+ ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+ }
+
+ if (free_ctx) {
+ ggml_free(ctx);
+ }
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
--- /dev/null
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+#include <stdbool.h>
+
+#define GGML_MAX_DIMS 4
+#define GGML_MAX_NODES 4096
+#define GGML_MAX_PARAMS 16
+#define GGML_MAX_CONTEXTS 16
+
+#ifdef __ARM_NEON
+// we use the built-in 16-bit float type
+typedef __fp16 ggml_fp16_t;
+#else
+typedef uint16_t ggml_fp16_t;
+#endif
+
+float ggml_fp16_to_fp32(ggml_fp16_t x);
+ggml_fp16_t ggml_fp32_to_fp16(float x);
+
+struct ggml_object;
+struct ggml_context;
+
+enum ggml_type {
+ GGML_TYPE_I8,
+ GGML_TYPE_I16,
+ GGML_TYPE_I32,
+ GGML_TYPE_F16,
+ GGML_TYPE_F32,
+ GGML_TYPE_COUNT,
+};
+
+enum ggml_op {
+ GGML_OP_NONE = 0,
+
+ GGML_OP_DUP,
+ GGML_OP_ADD,
+ GGML_OP_SUB,
+ GGML_OP_MUL,
+ GGML_OP_DIV,
+ GGML_OP_SQR,
+ GGML_OP_SQRT,
+ GGML_OP_SUM,
+ GGML_OP_MEAN,
+ GGML_OP_REPEAT,
+ GGML_OP_ABS,
+ GGML_OP_SGN,
+ GGML_OP_NEG,
+ GGML_OP_STEP,
+ GGML_OP_RELU,
+ GGML_OP_GELU,
+ GGML_OP_NORM, // normalize
+
+ GGML_OP_MUL_MAT,
+
+ GGML_OP_SCALE,
+ GGML_OP_CPY,
+ GGML_OP_RESHAPE,
+ GGML_OP_VIEW,
+ GGML_OP_PERMUTE,
+ GGML_OP_TRANSPOSE,
+ GGML_OP_GET_ROWS,
+ GGML_OP_DIAG_MASK_INF,
+ GGML_OP_SOFT_MAX,
+ GGML_OP_ROPE,
+ GGML_OP_CONV_1D_1S,
+ GGML_OP_CONV_1D_2S,
+
+ GGML_OP_COUNT,
+};
+
+// n-dimensional tensor
+struct ggml_tensor {
+ enum ggml_type type;
+
+ int n_dims;
+ int ne[GGML_MAX_DIMS]; // number of elements
+ size_t nb[GGML_MAX_DIMS]; // stride in bytes:
+ // nb[0] = sizeof(type)
+ // nb[1] = nb[0] * ne[0] + padding
+ // nb[i] = nb[i-1] * ne[i-1]
+
+ // compute data
+ enum ggml_op op;
+
+ bool is_param;
+
+ struct ggml_tensor * grad;
+ struct ggml_tensor * src0;
+ struct ggml_tensor * src1;
+
+ // thread scheduling
+ int n_tasks;
+
+ // performance
+ int perf_runs;
+ int64_t perf_cycles;
+ int64_t perf_time_us;
+
+ void * data;
+ char pad[8];
+};
+
+// computation graph
+struct ggml_cgraph {
+ int n_nodes;
+ int n_leafs;
+ int n_threads;
+
+ size_t work_size;
+ struct ggml_tensor * work;
+
+ struct ggml_tensor * nodes[GGML_MAX_NODES];
+ struct ggml_tensor * grads[GGML_MAX_NODES];
+ struct ggml_tensor * leafs[GGML_MAX_NODES];
+
+ // performance
+ int perf_runs;
+ int64_t perf_cycles;
+ int64_t perf_time_us;
+};
+
+struct ggml_init_params {
+ // memory pool
+ size_t mem_size; // bytes
+ void * mem_buffer; // if NULL, memory will be allocated internally
+};
+
+int64_t ggml_time_ms(void);
+int64_t ggml_time_us(void);
+int64_t ggml_cycles(void);
+int64_t ggml_cycles_per_ms(void);
+
+void ggml_print_object (const struct ggml_object * obj);
+void ggml_print_objects(const struct ggml_context * ctx);
+
+int ggml_nelements(const struct ggml_tensor * tensor);
+size_t ggml_nbytes (const struct ggml_tensor * tensor);
+
+size_t ggml_type_size (enum ggml_type type);
+size_t ggml_element_size(const struct ggml_tensor * tensor);
+
+struct ggml_context * ggml_init(struct ggml_init_params params);
+void ggml_free(struct ggml_context * ctx);
+
+size_t ggml_used_mem(const struct ggml_context * ctx);
+
+struct ggml_tensor * ggml_new_tensor(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int *ne);
+
+struct ggml_tensor * ggml_new_tensor_1d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0);
+
+struct ggml_tensor * ggml_new_tensor_2d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1);
+
+struct ggml_tensor * ggml_new_tensor_3d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2);
+
+struct ggml_tensor * ggml_new_tensor_4d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3);
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+ void * ggml_get_data (const struct ggml_tensor * tensor);
+float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+//
+// operations on tensors with backpropagation
+//
+
+struct ggml_tensor * ggml_dup(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_sub(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_mul(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_div(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_sqr(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_sqrt(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// return scalar
+// TODO: compute sum along rows
+struct ggml_tensor * ggml_sum(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// mean along rows
+struct ggml_tensor * ggml_mean(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// if a is the same shape as b, and a is not parameter, return a
+// otherwise, return a new tensor: repeat(a) to fit in b
+struct ggml_tensor * ggml_repeat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_abs(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_sgn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_neg(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_step(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// TODO: double-check this computation is correct
+struct ggml_tensor * ggml_gelu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// normalize along rows
+// TODO: eps is hardcoded to 1e-5 for now
+struct ggml_tensor * ggml_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// A: m rows, n columns
+// B: p rows, n columns (i.e. we transpose it internally)
+// result is m columns, p rows
+struct ggml_tensor * ggml_mul_mat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+//
+// operations on tensors without backpropagation
+//
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_scale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// a -> b, return view(b)
+struct ggml_tensor * ggml_cpy(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// return view(a), b specifies the new shape
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2);
+
+// offset in bytes
+struct ggml_tensor * ggml_view_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ size_t offset);
+
+struct ggml_tensor * ggml_view_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ size_t nb1, // row stride in bytes
+ size_t offset);
+
+struct ggml_tensor * ggml_permute(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int axis0,
+ int axis1,
+ int axis2,
+ int axis3);
+
+// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+struct ggml_tensor * ggml_transpose(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_get_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// set elements above the diagonal to -INF
+// in-place, returns view(a)
+struct ggml_tensor * ggml_diag_mask_inf(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past);
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_soft_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// rotary position embedding
+// in-place, returns view(a)
+// if mode == 1, skip n_past elements
+// TODO: avoid creating a new tensor every time
+struct ggml_tensor * ggml_rope(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode);
+
+// padding = 1
+// TODO: we don't support extra parameters for now
+// that's why we are hard-coding the stride, padding, and dilation
+// not great ..
+struct ggml_tensor * ggml_conv_1d_1s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_conv_1d_2s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+//
+// automatic differentiation
+//
+
+void ggml_set_param(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor);
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+
+struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+void ggml_graph_reset (struct ggml_cgraph * cgraph);
+
+// print info and performance information for the graph
+void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+// dump the graph into a file using the dot format
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+
+//
+// optimization
+//
+
+// optimization methods
+enum ggml_opt_type {
+ GGML_OPT_ADAM,
+ GGML_OPT_LBFGS,
+};
+
+// linesearch methods
+enum ggml_linesearch {
+ GGML_LINESEARCH_DEFAULT = 1,
+
+ GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
+ GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
+ GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+};
+
+// optimization return values
+enum ggml_opt_result {
+ GGML_OPT_OK = 0,
+ GGML_OPT_DID_NOT_CONVERGE,
+ GGML_OPT_NO_CONTEXT,
+ GGML_OPT_INVALID_WOLFE,
+ GGML_OPT_FAIL,
+
+ GGML_LINESEARCH_FAIL = -128,
+ GGML_LINESEARCH_MINIMUM_STEP,
+ GGML_LINESEARCH_MAXIMUM_STEP,
+ GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+ GGML_LINESEARCH_INVALID_PARAMETERS,
+};
+
+// optimization parameters
+//
+// see ggml.c (ggml_opt_default_params) for default values
+//
+struct ggml_opt_params {
+ enum ggml_opt_type type;
+
+ int n_threads;
+
+ // delta-based convergence test
+ //
+ // if past == 0 - disabled
+ // if past > 0:
+ // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+ //
+ int past;
+ float delta;
+
+ // maximum number of iterations without improvement
+ //
+ // if 0 - disabled
+ // if > 0:
+ // assume convergence if no cost improvement in this number of iterations
+ //
+ int max_no_improvement;
+
+ bool print_forward_graph;
+ bool print_backward_graph;
+
+ union {
+ // ADAM parameters
+ struct {
+ int n_iter;
+
+ float alpha; // learning rate
+ float beta1;
+ float beta2;
+ float eps; // epsilon for numerical stability
+ float eps_f; // epsilon for convergence test
+ float eps_g; // epsilon for convergence test
+ } adam;
+
+ // LBFGS parameters
+ struct {
+ int m; // number of corrections to approximate the inv. Hessian
+ int n_iter;
+ int max_linesearch;
+
+ float eps; // convergence tolerance
+ float ftol; // line search tolerance
+ float wolfe;
+ float min_step;
+ float max_step;
+
+ enum ggml_linesearch linesearch;
+ } lbfgs;
+ };
+};
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+// optimize the function defined by the tensor f
+enum ggml_opt_result ggml_opt(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f);
+
+#ifdef __cplusplus
+}
+#endif
--- /dev/null
+#include "ggml.h"
+
+// third-party utilities
+// use your favorite implementations
+#define DR_WAV_IMPLEMENTATION
+#include "dr_wav.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <thread>
+#include <vector>
+
+enum e_model {
+ MODEL_UNKNOWN,
+ MODEL_TINY,
+ MODEL_BASE,
+ MODEL_SMALL,
+ MODEL_MEDIUM,
+ MODEL_LARGE,
+};
+
+const size_t MB = 1024*1024;
+
+const std::map<e_model, size_t> MEM_REQ_MODEL = {
+ { MODEL_TINY, 100ull*MB },
+ { MODEL_BASE, 190ull*MB },
+ { MODEL_SMALL, 610ull*MB },
+ { MODEL_MEDIUM, 1900ull*MB },
+ { MODEL_LARGE, 3600ull*MB },
+};
+
+const std::map<e_model, size_t> MEM_REQ_ENCODE = {
+ { MODEL_TINY, 80ull*MB },
+ { MODEL_BASE, 128ull*MB },
+ { MODEL_SMALL, 300ull*MB },
+ { MODEL_MEDIUM, 680ull*MB },
+ { MODEL_LARGE, 1100ull*MB },
+};
+
+const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
+ { MODEL_TINY, 170ull*MB },
+ { MODEL_BASE, 230ull*MB },
+ { MODEL_SMALL, 350ull*MB },
+ { MODEL_MEDIUM, 450ull*MB },
+ { MODEL_LARGE, 570ull*MB },
+};
+
+const std::map<e_model, size_t> MEM_REQ_DECODE = {
+ { MODEL_TINY, 190ull*MB },
+ { MODEL_BASE, 190ull*MB },
+ { MODEL_SMALL, 190ull*MB },
+ { MODEL_MEDIUM, 200ull*MB },
+ { MODEL_LARGE, 200ull*MB },
+};
+
+const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
+ { MODEL_TINY, 32ull*MB },
+ { MODEL_BASE, 44ull*MB },
+ { MODEL_SMALL, 64ull*MB },
+ { MODEL_MEDIUM, 84ull*MB },
+ { MODEL_LARGE, 110ull*MB },
+};
+
+const int SAMPLE_RATE = 16000;
+const int N_FFT = 400;
+const int N_MEL = 80;
+const int HOP_LENGTH = 160;
+const int CHUNK_SIZE = 30; // seconds
+
+struct whisper_mel {
+ int n_len;
+ int n_mel;
+
+ std::vector<float> data;
+};
+
+struct whisper_filters {
+ int32_t n_mel;
+ int32_t n_fft;
+
+ std::vector<float> data;
+};
+
+struct whisper_vocab {
+ using id = int32_t;
+ using token = std::string;
+
+ int n_vocab = 51864;
+
+ std::map<token, id> token_to_id;
+ std::map<id, token> id_to_token;
+
+ id token_eot = 50256;
+ id token_sot = 50257;
+ id token_prev = 50360;
+ id token_solm = 50361; // ??
+ id token_beg = 50363;
+
+ bool is_multilingual() const {
+ return n_vocab == 51865;
+ }
+};
+
+// command-line parameters
+struct whisper_params {
+ int32_t seed = -1; // RNG seed
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+
+ int32_t max_tokens_per_iter = 64;
+
+ bool verbose = false;
+ bool print_special_tokens = false;
+
+ std::string model = "models/whisper-tiny.en/ggml-model.bin"; // model path
+
+ std::string fname_inp = "default.wav";
+};
+
+void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
+
+bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
+ for (int i = 1; i < argc; i++) {
+ std::string arg = argv[i];
+
+ if (arg == "-s" || arg == "--seed") {
+ params.seed = std::stoi(argv[++i]);
+ } else if (arg == "-t" || arg == "--threads") {
+ params.n_threads = std::stoi(argv[++i]);
+ } else if (arg == "-T" || arg == "--tokens") {
+ params.max_tokens_per_iter = std::stoi(argv[++i]);
+ } else if (arg == "-v" || arg == "--verbose") {
+ params.verbose = true;
+ } else if (arg == "-ps" || arg == "--print_special") {
+ params.print_special_tokens = true;
+ } else if (arg == "-m" || arg == "--model") {
+ params.model = argv[++i];
+ } else if (arg == "-f" || arg == "--file") {
+ params.fname_inp = argv[++i];
+ } else if (arg == "-h" || arg == "--help") {
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ } else {
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ }
+ }
+
+ return true;
+}
+
+void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
+ fprintf(stderr, "\n");
+ fprintf(stderr, "options:\n");
+ fprintf(stderr, " -h, --help show this help message and exit\n");
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
+ fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
+ fprintf(stderr, " -v, --verbose verbose output\n");
+ fprintf(stderr, " -ps, --print_special print special tokens\n");
+ fprintf(stderr, " -m FNAME, --model FNAME\n");
+ fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
+ fprintf(stderr, " -f FNAME, --file FNAME\n");
+ fprintf(stderr, " input WAV file path (default: %s)\n", params.fname_inp.c_str());
+ fprintf(stderr, "\n");
+}
+
+
+// medium
+// hparams: {
+// 'n_mels': 80,
+// 'n_vocab': 51864,
+// 'n_audio_ctx': 1500,
+// 'n_audio_state': 1024,
+// 'n_audio_head': 16,
+// 'n_audio_layer': 24,
+// 'n_text_ctx': 448,
+// 'n_text_state': 1024,
+// 'n_text_head': 16,
+// 'n_text_layer': 24
+// }
+//
+// default hparams (Whisper tiny)
+struct whisper_hparams {
+ int32_t n_vocab = 51864;
+ int32_t n_audio_ctx = 1500;
+ int32_t n_audio_state = 384;
+ int32_t n_audio_head = 6;
+ int32_t n_audio_layer = 4;
+ int32_t n_text_ctx = 448;
+ int32_t n_text_state = 384;
+ int32_t n_text_head = 6;
+ int32_t n_text_layer = 4;
+ int32_t n_mels = 80;
+ int32_t f16 = 1;
+};
+
+// audio encoding layer
+struct whisper_layer_encoder {
+ // encoder.blocks.*.attn_ln
+ struct ggml_tensor * attn_ln_0_w;
+ struct ggml_tensor * attn_ln_0_b;
+
+ // encoder.blocks.*.attn.out
+ struct ggml_tensor * attn_ln_1_w;
+ struct ggml_tensor * attn_ln_1_b;
+
+ // encoder.blocks.*.attn.query
+ struct ggml_tensor * attn_q_w;
+ struct ggml_tensor * attn_q_b;
+
+ // encoder.blocks.*.attn.key
+ struct ggml_tensor * attn_k_w;
+
+ // encoder.blocks.*.attn.value
+ struct ggml_tensor * attn_v_w;
+ struct ggml_tensor * attn_v_b;
+
+ // encoder.blocks.*.mlp_ln
+ struct ggml_tensor * mlp_ln_w;
+ struct ggml_tensor * mlp_ln_b;
+
+ // encoder.blocks.*.mlp.0
+ struct ggml_tensor * mlp_0_w;
+ struct ggml_tensor * mlp_0_b;
+
+ // encoder.blocks.*.mlp.2
+ struct ggml_tensor * mlp_1_w;
+ struct ggml_tensor * mlp_1_b;
+};
+
+// token decoding layer
+struct whisper_layer_decoder {
+ // decoder.blocks.*.attn_ln
+ struct ggml_tensor * attn_ln_0_w;
+ struct ggml_tensor * attn_ln_0_b;
+
+ // decoder.blocks.*.attn.out
+ struct ggml_tensor * attn_ln_1_w;
+ struct ggml_tensor * attn_ln_1_b;
+
+ // decoder.blocks.*.attn.query
+ struct ggml_tensor * attn_q_w;
+ struct ggml_tensor * attn_q_b;
+
+ // decoder.blocks.*.attn.key
+ struct ggml_tensor * attn_k_w;
+
+ // decoder.blocks.*.attn.value
+ struct ggml_tensor * attn_v_w;
+ struct ggml_tensor * attn_v_b;
+
+ // decoder.blocks.*.cross_attn_ln
+ struct ggml_tensor * cross_attn_ln_0_w;
+ struct ggml_tensor * cross_attn_ln_0_b;
+
+ // decoder.blocks.*.cross_attn.out
+ struct ggml_tensor * cross_attn_ln_1_w;
+ struct ggml_tensor * cross_attn_ln_1_b;
+
+ // decoder.blocks.*.cross_attn.query
+ struct ggml_tensor * cross_attn_q_w;
+ struct ggml_tensor * cross_attn_q_b;
+
+ // decoder.blocks.*.cross_attn.key
+ struct ggml_tensor * cross_attn_k_w;
+
+ // decoder.blocks.*.cross_attn.value
+ struct ggml_tensor * cross_attn_v_w;
+ struct ggml_tensor * cross_attn_v_b;
+
+ // decoder.blocks.*.mlp_ln
+ struct ggml_tensor * mlp_ln_w;
+ struct ggml_tensor * mlp_ln_b;
+
+ // decoder.blocks.*.mlp.0
+ struct ggml_tensor * mlp_0_w;
+ struct ggml_tensor * mlp_0_b;
+
+ // decoder.blocks.*.mlp.2
+ struct ggml_tensor * mlp_1_w;
+ struct ggml_tensor * mlp_1_b;
+};
+
+struct whisper_model {
+ e_model type = MODEL_UNKNOWN;
+
+ whisper_hparams hparams;
+ whisper_filters filters;
+
+ // encoder.positional_embedding
+ struct ggml_tensor * e_pe;
+
+ // encoder.conv1
+ struct ggml_tensor * e_conv_1_w;
+ struct ggml_tensor * e_conv_1_b;
+
+ // encoder.conv2
+ struct ggml_tensor * e_conv_2_w;
+ struct ggml_tensor * e_conv_2_b;
+
+ // encoder.ln_post
+ struct ggml_tensor * e_ln_w;
+ struct ggml_tensor * e_ln_b;
+
+ // decoder.positional_embedding
+ struct ggml_tensor * d_pe; // DD
+
+ // decoder.token_embedding
+ struct ggml_tensor * d_te; // DD
+
+ // decoder.ln
+ struct ggml_tensor * d_ln_w; // DD
+ struct ggml_tensor * d_ln_b; // DD
+
+ std::vector<whisper_layer_encoder> layers_encoder;
+ std::vector<whisper_layer_decoder> layers_decoder;
+
+ // key + value memory
+ struct ggml_tensor * memory_k;
+ struct ggml_tensor * memory_v;
+
+ struct ggml_tensor * memory_cross_k;
+ struct ggml_tensor * memory_cross_v;
+
+ //
+ struct ggml_context * ctx;
+ std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model from a ggml file
+//
+// file format:
+//
+// - hparams
+// - pre-computed mel filters
+// - vocab
+// - weights
+//
+// see the convert-pt-to-ggml.py script for details
+//
+bool whisper_model_load(const std::string & fname, whisper_model & model, whisper_vocab & vocab) {
+ printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+ auto fin = std::ifstream(fname, std::ios::binary);
+ if (!fin) {
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+ return false;
+ }
+
+ // verify magic
+ {
+ uint32_t magic;
+ fin.read((char *) &magic, sizeof(magic));
+ if (magic != 0x67676d6c) {
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+ return false;
+ }
+ }
+
+ //load hparams
+ {
+ auto & hparams = model.hparams;
+
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+ fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
+ fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
+ fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
+ fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
+ fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
+ fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
+ fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
+ fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
+ fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
+ fin.read((char *) &hparams.f16, sizeof(hparams.f16));
+
+ assert(hparams.n_text_state == hparams.n_audio_state);
+
+ if (hparams.n_audio_layer == 4) {
+ model.type = e_model::MODEL_TINY;
+ }
+
+ if (hparams.n_audio_layer == 6) {
+ model.type = e_model::MODEL_BASE;
+ }
+
+ if (hparams.n_audio_layer == 12) {
+ model.type = e_model::MODEL_SMALL;
+ }
+
+ if (hparams.n_audio_layer == 24) {
+ model.type = e_model::MODEL_MEDIUM;
+ }
+
+ if (hparams.n_audio_layer == 32) {
+ model.type = e_model::MODEL_LARGE;
+ }
+
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+ printf("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
+ printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+ printf("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
+ printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+ printf("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
+ printf("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
+ printf("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
+ printf("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
+ printf("%s: n_mels = %d\n", __func__, hparams.n_mels);
+ printf("%s: f16 = %d\n", __func__, hparams.f16);
+ printf("%s: type = %d\n", __func__, model.type);
+
+ const size_t mem_required =
+ MEM_REQ_MODEL.at(model.type) +
+ MEM_REQ_ENCODE.at(model.type) +
+ MEM_REQ_ENCODE_LAYER.at(model.type) +
+ MEM_REQ_DECODE.at(model.type) +
+ MEM_REQ_DECODE_LAYER.at(model.type);
+
+ printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+ }
+
+ // load mel filters
+ {
+ auto & filters = model.filters;
+
+ fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
+ fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
+
+ filters.data.resize(filters.n_mel * filters.n_fft);
+ fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
+ }
+
+ // load vocab
+ {
+ int32_t n_vocab = 0;
+ fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+ //if (n_vocab != model.hparams.n_vocab) {
+ // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+ // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+ // return false;
+ //}
+
+ std::string word;
+ for (int i = 0; i < n_vocab; i++) {
+ uint32_t len;
+ fin.read((char *) &len, sizeof(len));
+
+ word.resize(len);
+ fin.read((char *) word.data(), len);
+
+ vocab.token_to_id[word] = i;
+ vocab.id_to_token[i] = word;
+
+ //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
+ }
+
+ vocab.n_vocab = model.hparams.n_vocab;
+ if (vocab.is_multilingual()) {
+ vocab.token_eot++;
+ vocab.token_sot++;
+ vocab.token_prev++;
+ vocab.token_solm++;
+ vocab.token_beg++;
+ }
+
+ if (n_vocab < model.hparams.n_vocab) {
+ printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+ for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
+ if (i > vocab.token_beg) {
+ word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
+ } else if (i == vocab.token_eot) {
+ word = "[_EOT_]";
+ } else if (i == vocab.token_sot) {
+ word = "[_SOT_]";
+ } else if (i == vocab.token_prev) {
+ word = "[_PREV_]";
+ } else if (i == vocab.token_beg) {
+ word = "[_BEG_]";
+ } else {
+ word = "[_extra_token_" + std::to_string(i) + "]";
+ }
+ vocab.token_to_id[word] = i;
+ vocab.id_to_token[i] = word;
+ }
+ }
+ }
+
+ // for the big tensors, we have the option to store the data in 16-bit floats
+ // in order to save memory and also to speed up the computation
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ auto & ctx = model.ctx;
+
+ size_t ctx_size = 0;
+
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_state = hparams.n_audio_state;
+ const int n_audio_layer = hparams.n_audio_layer;
+
+ const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mels = hparams.n_mels;
+
+ // encoder
+ {
+ // TODO: F16 .. maybe not?
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
+
+ ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
+
+ ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
+
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
+ }
+
+ // decoder
+ {
+ // TODO: F16 .. maybe not?
+ ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
+
+ ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
+
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
+ }
+
+ // encoder layers
+ {
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
+
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
+ ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
+
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
+
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
+ }
+
+ // decoder layers
+ {
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
+
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
+ ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
+
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
+
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
+ //
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
+ }
+
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_k
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_v
+
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_k
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_v
+
+ ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
+
+ printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+ }
+
+ // create the ggml context
+ {
+ struct ggml_init_params params = {
+ .mem_size = ctx_size,
+ .mem_buffer = NULL,
+ };
+
+ model.ctx = ggml_init(params);
+ if (!model.ctx) {
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+ return false;
+ }
+ }
+
+ // prepare memory for the weights
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_state = hparams.n_audio_state;
+ const int n_audio_layer = hparams.n_audio_layer;
+
+ const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mels = hparams.n_mels;
+
+ model.layers_encoder.resize(n_audio_layer);
+ model.layers_decoder.resize(n_text_layer);
+
+ // encoder
+ {
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ // map by name
+ model.tensors["encoder.positional_embedding"] = model.e_pe;
+
+ model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
+ model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
+
+ model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
+ model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
+
+ model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
+ model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
+
+ for (int i = 0; i < n_audio_layer; ++i) {
+ auto & layer = model.layers_encoder[i];
+
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ // map by name
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
+ }
+ }
+
+ // decoder
+ {
+ model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
+
+ model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
+
+ model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ // map by name
+ model.tensors["decoder.positional_embedding"] = model.d_pe;
+
+ model.tensors["decoder.token_embedding.weight"] = model.d_te;
+
+ model.tensors["decoder.ln.weight"] = model.d_ln_w;
+ model.tensors["decoder.ln.bias"] = model.d_ln_b;
+
+ for (int i = 0; i < n_text_layer; ++i) {
+ auto & layer = model.layers_decoder[i];
+
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
+
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+
+ layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ // map by name
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
+ }
+ }
+ }
+
+ // key + value memory
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+ const int n_text_ctx = hparams.n_text_ctx;
+
+ {
+ const int n_mem = n_text_layer*n_text_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+ }
+
+ {
+ const int n_audio_ctx = hparams.n_audio_ctx;
+
+ const int n_mem = n_text_layer*n_audio_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+ }
+
+ const size_t memory_size =
+ ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
+ ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
+
+ printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
+ }
+
+ // load weights
+ {
+ size_t total_size = 0;
+
+ while (true) {
+ int32_t n_dims;
+ int32_t length;
+ int32_t ftype;
+
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+ fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
+
+ if (fin.eof()) {
+ break;
+ }
+
+ int32_t nelements = 1;
+ int32_t ne[3] = { 1, 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+ nelements *= ne[i];
+ }
+
+ std::string name(length, 0);
+ fin.read(&name[0], length);
+
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+ return false;
+ }
+
+ auto tensor = model.tensors[name.data()];
+ if (ggml_nelements(tensor) != nelements) {
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+ return false;
+ }
+
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+ __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
+ return false;
+ }
+
+ const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+
+ if (nelements*bpe != ggml_nbytes(tensor)) {
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+ return false;
+ }
+
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+ //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+ total_size += ggml_nbytes(tensor);
+ }
+
+ printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+ }
+
+ fin.close();
+
+ return true;
+}
+
+// evaluate the encoder
+//
+// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
+// part of the transformer model and returns the encoded features
+//
+// - model: the model
+// - n_threads: number of threads to use
+// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
+// - mel_inp: input mel spectrogram
+// - features: output encoded features
+//
+bool whisper_encode(
+ const whisper_model & model,
+ const int n_threads,
+ const int mel_offset,
+ const whisper_mel & mel_inp,
+ std::vector<float> & features) {
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_ctx = hparams.n_audio_ctx;
+ const int n_state = hparams.n_audio_state;
+ const int n_head = hparams.n_audio_head;
+ const int n_layer = hparams.n_audio_layer;
+
+ const int N = n_ctx;
+
+ const int n_mels = hparams.n_mels;
+ assert(mel_inp.n_mel == n_mels);
+
+ struct ggml_init_params params;
+
+ {
+ static size_t buf_size = MEM_REQ_ENCODE.at(model.type);
+ static void * buf = malloc(buf_size);
+
+ params = {
+ .mem_size = buf_size,
+ .mem_buffer = buf,
+ };
+ }
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+ assert(mel->type == GGML_TYPE_F32);
+ {
+ float * dst = (float *) mel->data;
+ memset(dst, 0, ggml_nbytes(mel));
+
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
+ for (int i = i0; i < i1; ++i) {
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+ }
+ }
+ }
+
+ struct ggml_tensor * cur;
+
+ // convolution + gelu
+ {
+ cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ model.e_conv_1_b,
+ cur),
+ cur);
+
+ cur = ggml_gelu(ctx0, cur);
+
+ cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ model.e_conv_2_b,
+ cur),
+ cur);
+
+ cur = ggml_gelu(ctx0, cur);
+ }
+
+ cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_encoder[il];
+
+ // create separate context for each layer to reduce memory usage
+
+ struct ggml_init_params paramsL;
+ {
+ static size_t buf_size = MEM_REQ_ENCODE_LAYER.at(model.type);
+ static void * buf = malloc(buf_size);
+
+ paramsL = {
+ .mem_size = buf_size,
+ .mem_buffer = buf,
+ };
+ }
+
+ struct ggml_context * ctxL = ggml_init(paramsL);
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpL);
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_q_b,
+ Qcur),
+ Qcur);
+
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ layer.attn_k_w,
+ cur);
+
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_v_b,
+ Vcur),
+ Vcur);
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Kcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), // F16 !
+ 0, 2, 1, 3);
+
+ //// BLAS attempt
+ //struct ggml_tensor * KQ =
+ // ggml_mul_mat(ctxL,
+ // ggml_cpy(ctxL, K, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)),
+ // ggml_cpy(ctxL, Q, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)));
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * K =
+ // ggml_cpy(ctxL,
+ // ggml_permute(ctxL,
+ // ggml_reshape_3d(ctxL,
+ // Kcur,
+ // n_state/n_head, n_head, N),
+ // 1, 2, 0, 3),
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
+ // );
+
+ //// K * Q
+ //struct ggml_tensor * KQ = ggml_mul_mat(ctxL, ggml_transpose(ctxL, K), Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+
+ //struct ggml_tensor * V_trans =
+ // ggml_permute(ctxL,
+ // ggml_cpy(ctxL,
+ // Vcur,
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+ // 1, 2, 0, 3);
+
+ //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+
+ struct ggml_tensor * V =
+ ggml_cpy(ctxL,
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ Vcur,
+ n_state/n_head, n_head, N),
+ 0, 2, 1, 3),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) // F16 !
+ );
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ cur = ggml_add(ctxL, cur, inpL);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpFF);
+
+ // cur = mlp_ln_w*cur + mlp_ln_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ }
+
+ // fully connected
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ cur);
+
+ // GELU activation
+ cur = ggml_gelu(ctxL, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
+ cur);
+ }
+
+ // output from this layer
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
+
+ {
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ ggml_build_forward_expand(&gf, inpO);
+ ggml_graph_compute (ctxL, &gf);
+
+ //ggml_graph_print(&gf);
+ }
+
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
+ // input for next layer (inpO -> inpL)
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
+ inpL->op = GGML_OP_NONE;
+ inpL->src0 = NULL;
+ inpL->src1 = NULL;
+
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+
+ ggml_free(ctxL);
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur);
+
+ // cur = ln_f_g*cur + ln_f_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.e_ln_w, cur),
+ cur),
+ ggml_repeat(ctx0, model.e_ln_b, cur));
+ }
+
+ // run the computation
+ {
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ ggml_build_forward_expand(&gf, cur);
+ ggml_graph_compute (ctx0, &gf);
+
+ //ggml_graph_print(&gf);
+ }
+
+ // cur
+ //{
+ // printf("ne0 = %d\n", cur->ne[0]);
+ // printf("ne1 = %d\n", cur->ne[1]);
+ // for (int i = 0; i < 10; ++i) {
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
+ // }
+ // printf("... ");
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
+ // }
+ // printf("\n");
+ //}
+
+ // pre-compute cross-attention memory
+ {
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ // TODO: hack to disconnect the encoded features from the previous graph
+ cur->op = GGML_OP_NONE;
+ cur->src0 = NULL;
+ cur->src1 = NULL;
+
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
+ auto & layer = model.layers_decoder[il];
+
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_k_w,
+ cur);
+
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_v_w,
+ cur);
+
+ Vcross = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ layer.cross_attn_v_b,
+ Vcross),
+ Vcross);
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
+ struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
+ }
+
+ ggml_graph_compute(ctx0, &gf);
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ // output the features
+ assert(cur->type == GGML_TYPE_F32);
+ features.resize(cur->ne[0]*cur->ne[1]);
+ memcpy(features.data(), cur->data, features.size()*sizeof(float));
+
+ //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
+
+ ggml_free(ctx0);
+
+ return true;
+}
+
+// evaluate the decoder
+//
+// given text prompt + audio features -> predicts the probabilities for the next token
+//
+// - model: the model
+// - n_threads: number of threads to use
+// - n_past: prompt length
+// - prompt: text prompt
+// - logits_out: output logits
+// - probs_out: output probabilities
+//
+bool whisper_decode(
+ const whisper_model & model,
+ const int n_threads,
+ const int n_past,
+ const std::vector<whisper_vocab::id> & prompt,
+ std::vector<float> & logits_out,
+ std::vector<float> & probs_out) {
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_ctx = hparams.n_text_ctx;
+ const int n_state = hparams.n_text_state;
+ const int n_head = hparams.n_text_head;
+ const int n_layer = hparams.n_text_layer;
+
+ const int N = prompt.size();
+ const int M = hparams.n_audio_ctx;
+
+ struct ggml_init_params params;
+
+ {
+ static size_t buf_size = MEM_REQ_DECODE.at(model.type);
+ static void * buf = malloc(buf_size);
+
+ params = {
+ .mem_size = buf_size,
+ .mem_buffer = buf,
+ };
+ }
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ memcpy(embd->data, prompt.data(), N*ggml_element_size(embd));
+
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ for (int i = 0; i < N; ++i) {
+ ((int32_t *) position->data)[i] = n_past + i;
+ }
+
+ // wte + wpe
+ struct ggml_tensor * cur =
+ ggml_add(ctx0,
+ ggml_get_rows(ctx0, model.d_te, embd),
+ ggml_get_rows(ctx0, model.d_pe, position));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_decoder[il];
+
+ struct ggml_init_params paramsL;
+
+ {
+ static size_t buf_size = MEM_REQ_DECODE_LAYER.at(model.type);
+ static void * buf = malloc(buf_size);
+
+ paramsL = {
+ .mem_size = buf_size,
+ .mem_buffer = buf,
+ };
+ }
+
+ struct ggml_context * ctxL = ggml_init(paramsL);
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpL);
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_q_b,
+ Qcur),
+ Qcur);
+
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ layer.attn_k_w,
+ cur);
+
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_v_b,
+ Vcur),
+ Vcur);
+
+ // store key and value to memory
+ {
+ struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
+ }
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+ n_state/n_head, n_head, n_past + N),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
+
+ struct ggml_tensor * V_trans =
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+ n_state/n_head, n_head, n_past + N),
+ 1, 2, 0, 3);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpCA); // Note we use inpCA here
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
+ }
+
+ // cross-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.cross_attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.cross_attn_q_b,
+ Qcur),
+ Qcur);
+
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // Kcross is already scaled
+ struct ggml_tensor * Kcross =
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+ n_state/n_head, n_head, M);
+
+ struct ggml_tensor * Vcross =
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+ n_state/n_head, n_head, M);
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ // no masking for cross-attention
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+
+ struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ // cur = KQV_merged.contiguous().view(n_state, N)
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.cross_attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
+ cur);
+ }
+
+
+ // add the input
+ cur = ggml_add(ctxL, cur, inpCA);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpFF);
+
+ // cur = ln_2_g*cur + ln_2_b
+ // [ 768, N]
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ }
+
+ // fully connected
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ cur);
+
+ // GELU activation
+ cur = ggml_gelu(ctxL, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
+ cur);
+ }
+
+ // output from this layer
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
+
+ {
+ ggml_build_forward_expand(&gf, inpO);
+ ggml_graph_compute (ctxL, &gf);
+
+ //ggml_graph_print(&gf);
+ }
+
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
+ // input for next layer (inpO -> inpL)
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
+ inpL->op = GGML_OP_NONE;
+ inpL->src0 = NULL;
+ inpL->src1 = NULL;
+
+ if (N > 1) {
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+ }
+
+ ggml_free(ctxL);
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur);
+
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.d_ln_w, cur),
+ cur),
+ ggml_repeat(ctx0, model.d_ln_b, cur));
+ }
+
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+
+ // logits -> probs
+ cur = ggml_dup(ctx0, logits);
+ cur = ggml_soft_max(ctx0, cur); // in-place
+
+ // run the computation
+ {
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ ggml_build_forward_expand(&gf, cur);
+ ggml_graph_compute (ctx0, &gf);
+ }
+
+ logits_out.resize(N*n_vocab);
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+ probs_out.resize(N*n_vocab);
+ memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
+
+ //if (N > 1) {
+ // const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
+ // printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
+ // printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+ //}
+
+ ggml_free(ctx0);
+
+ return true;
+}
+
+// the most basic sampling scheme - select the top token
+// TODO: beam search
+// TODO: temperature
+whisper_vocab::id whisper_sample_best(
+ const whisper_vocab & vocab,
+ const float * probs,
+ double temp,
+ int offset = 0) {
+ int n_logits = vocab.id_to_token.size();
+
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+ probs_id.reserve(n_logits);
+
+ for (int i = offset; i < n_logits; i++) {
+ probs_id.push_back(std::make_pair(probs[i], i));
+ }
+
+ const int top_k = 10;
+
+ // find the top K tokens
+ std::partial_sort(
+ probs_id.begin(),
+ probs_id.begin() + top_k, probs_id.end(),
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
+ return a.first > b.first;
+ });
+
+ probs_id.resize(top_k);
+
+ //printf("\n");
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
+ //}
+
+ int res = 0;
+ while (probs_id[res].second == vocab.token_solm && res < (int) probs_id.size() - 1) {
+ res++;
+ }
+
+ return probs_id[res].second;
+}
+
+// Cooley-Tukey FFT
+// poor man's implmentation - use something better
+// input is real-valued
+// output is complex-valued
+void fft(const std::vector<float> & in, std::vector<float> & out) {
+ out.resize(in.size()*2);
+
+ int N = in.size();
+
+ if (N == 1) {
+ out[0] = in[0];
+ out[1] = 0;
+ return;
+ }
+
+ std::vector<float> even;
+ std::vector<float> odd;
+
+ for (int i = 0; i < N; i++) {
+ if (i % 2 == 0) {
+ even.push_back(in[i]);
+ } else {
+ odd.push_back(in[i]);
+ }
+ }
+
+ std::vector<float> even_fft;
+ std::vector<float> odd_fft;
+
+ fft(even, even_fft);
+ fft(odd, odd_fft);
+
+ for (int k = 0; k < N/2; k++) {
+ float theta = 2*M_PI*k/N;
+
+ float re = cos(theta);
+ float im = -sin(theta);
+
+ float re_odd = odd_fft[2*k + 0];
+ float im_odd = odd_fft[2*k + 1];
+
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+ out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+ out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+ }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
+bool log_mel_spectrogram(
+ const std::vector<float> sf32,
+ const int sample_rate,
+ const int fft_size,
+ const int fft_step,
+ const int n_mel,
+ const int n_threads,
+ const whisper_filters & filters,
+ whisper_mel & mel) {
+ const int n_sample = sf32.size();
+ const float * samples = sf32.data();
+
+ // Hanning window
+ std::vector<float> hann;
+ hann.resize(fft_size);
+ for (int i = 0; i < fft_size; i++) {
+ hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
+ }
+
+ mel.n_mel = n_mel;
+ mel.n_len = (n_sample)/fft_step;
+ mel.data.resize(mel.n_mel*mel.n_len);
+
+ const int n_fft = 1 + fft_size/2;
+
+ printf("%s: n_sample = %d, n_len = %d\n", __func__, n_sample, mel.n_len);
+ printf("%s: recording length: %f s\n", __func__, (float) n_sample/sample_rate);
+
+ std::vector<std::thread> workers(n_threads);
+ for (int iw = 0; iw < n_threads; ++iw) {
+ workers[iw] = std::thread([&](int ith) {
+ std::vector<float> fft_in;
+ fft_in.resize(fft_size);
+ for (int i = 0; i < fft_size; i++) {
+ fft_in[i] = 0.0;
+ }
+
+ std::vector<float> fft_out;
+ fft_out.resize(2*fft_size);
+
+ for (int i = ith; i < mel.n_len; i += n_threads) {
+ const int offset = i*fft_step;
+
+ // apply Hanning window
+ for (int j = 0; j < fft_size; j++) {
+ if (offset + j < n_sample) {
+ fft_in[j] = hann[j]*samples[offset + j];
+ } else {
+ fft_in[j] = 0.0;
+ }
+ }
+
+ // FFT -> mag^2
+ fft(fft_in, fft_out);
+
+ for (int j = 0; j < n_fft; j++) {
+ fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
+ }
+
+ // mel spectrogram
+ for (int j = 0; j < mel.n_mel; j++) {
+ double sum = 0.0;
+
+ for (int k = 0; k < n_fft; k++) {
+ sum += fft_out[k]*filters.data[j*n_fft + k];
+ }
+ if (sum < 1e-10) {
+ sum = 1e-10;
+ }
+
+ sum = log10(sum);
+
+ mel.data[j*mel.n_len + i] = sum;
+ }
+ }
+ }, iw);
+ }
+
+ for (int iw = 0; iw < n_threads; ++iw) {
+ workers[iw].join();
+ }
+
+ // clamping and normalization
+ double mmax = -1e20;
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+ if (mel.data[i] > mmax) {
+ mmax = mel.data[i];
+ }
+ }
+
+ mmax -= 8.0;
+
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+ if (mel.data[i] < mmax) {
+ mel.data[i] = mmax;
+ }
+
+ mel.data[i] = (mel.data[i] + 4.0)/4.0;
+ }
+
+ return true;
+}
+
+int main(int argc, char ** argv) {
+ const int64_t t_main_start_us = ggml_time_us();
+
+ whisper_params params;
+ params.model = "models/whisper-tiny.en/ggml-model.bin";
+
+ if (whisper_params_parse(argc, argv, params) == false) {
+ return 1;
+ }
+
+ if (params.seed < 0) {
+ params.seed = time(NULL);
+ }
+
+ // Model loading
+
+ //printf("%s: seed = %d\n", __func__, params.seed);
+
+ int64_t t_load_us = 0;
+ int64_t t_mel_us = 0;
+ int64_t t_sample_us = 0;
+ int64_t t_encode_us = 0;
+ int64_t t_decode_us = 0;
+
+ whisper_vocab vocab;
+ whisper_model model;
+
+ // load the model
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_model_load(params.model, model, vocab)) {
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+ return 1;
+ }
+
+ t_load_us = ggml_time_us() - t_start_us;
+ }
+
+ // WAV input
+ std::vector<float> pcmf32;
+ {
+ drwav wav;
+ if (!drwav_init_file(&wav, params.fname_inp.c_str(), NULL)) {
+ fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], params.fname_inp.c_str());
+ return 2;
+ }
+
+ if (wav.channels != 1) {
+ fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str());
+ return 3;
+ }
+
+ if (wav.sampleRate != SAMPLE_RATE) {
+ fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str());
+ return 4;
+ }
+
+ if (wav.bitsPerSample != 16) {
+ fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], params.fname_inp.c_str());
+ return 5;
+ }
+
+ std::vector<int16_t> pcm16;
+ pcm16.resize(wav.totalPCMFrameCount);
+ drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
+ drwav_uninit(&wav);
+
+ // convert to float
+ pcmf32.resize(pcm16.size());
+ for (size_t i = 0; i < pcm16.size(); i++) {
+ pcmf32[i] = float(pcm16[i])/32768.0f;
+ }
+ }
+
+ // compute log mel spectrogram
+ whisper_mel mel_inp;
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ log_mel_spectrogram(pcmf32, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, params.n_threads, model.filters, mel_inp);
+
+ t_mel_us = ggml_time_us() - t_start_us;
+ }
+
+ std::vector<whisper_vocab::id> prompt_past = { };
+
+ // main loop
+ int seek = 0;
+ while (true) {
+ if (seek >= mel_inp.n_len) {
+ break;
+ }
+
+ // encode audio features starting at offset seek
+ std::vector<float> features;
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_encode(model, params.n_threads, seek, mel_inp, features)) {
+ fprintf(stderr, "%s: failed to eval\n", __func__);
+ return 1;
+ }
+
+ t_encode_us = ggml_time_us() - t_start_us;
+ }
+
+ std::vector<float> probs;
+ std::vector<float> logits;
+
+ // SOT
+ // ref: https://github.com/openai/whisper/blob/15ab54826343c27cfaf44ce31e9c8fb63d0aa775/whisper/decoding.py#L506-L526
+ // TODO: use different initial tokens for different tasks
+ std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
+
+ int n_past = 0;
+
+ if (prompt_past.size() > 0) {
+ int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
+
+ prompt = { vocab.token_prev };
+ prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end());
+ prompt.push_back(vocab.token_sot);
+
+ prompt_past.clear();
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - 1);
+ }
+
+ bool done = false;
+ int seek_delta = 100*CHUNK_SIZE;
+ whisper_vocab::id last_id = 0;
+
+ //for (int i = 0; i < prompt.size(); i++) {
+ // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
+ //}
+
+ printf("\n");
+ for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) {
+ // decode
+ if (prompt.size() > 0) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_decode(model, params.n_threads, n_past, prompt, logits, probs)) {
+ fprintf(stderr, "%s: failed to eval\n", __func__);
+ return 1;
+ }
+
+ t_decode_us += ggml_time_us() - t_start_us;
+ }
+
+ n_past += prompt.size();
+ prompt.clear();
+
+ {
+ // sample next token
+ const float temp = 1.0; // TODO
+
+ const int n_vocab = model.hparams.n_vocab;
+
+ whisper_vocab::id id = 0;
+
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), temp, i > params.max_tokens_per_iter ? vocab.token_beg : 0);
+
+ t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+
+ // end of text token
+ if (id == vocab.token_eot) {
+ break;
+ }
+
+ // 2 consecutive time tokens
+ if (id > vocab.token_beg && last_id > vocab.token_beg) {
+ seek_delta = 2*(id - vocab.token_beg);
+ done = true;
+ }
+ last_id = id;
+
+ // add it to the context
+ prompt.push_back(id);
+ prompt_past.push_back(id);
+ }
+
+ // display text
+ for (auto id : prompt) {
+ if (params.print_special_tokens == false && id >= vocab.token_eot) {
+ continue;
+ }
+ printf("%s", vocab.id_to_token[id].c_str());
+ }
+ fflush(stdout);
+
+ if (done) {
+ break;
+ }
+ }
+
+ seek += seek_delta;
+ }
+
+ // report timing
+ {
+ const int64_t t_main_end_us = ggml_time_us();
+
+ printf("\n\n");
+ printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+ printf("%s: mel time = %8.2f ms\n", __func__, t_mel_us/1000.0f);
+ printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+ printf("%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, t_encode_us/1000.0f, t_encode_us/1000.0f/model.hparams.n_audio_layer);
+ printf("%s: decode time = %8.2f ms\n", __func__, t_decode_us/1000.0f);
+ printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+ }
+
+ ggml_free(model.ctx);
+
+ return 0;
+}