compilade [Wed, 8 May 2024 22:16:38 +0000 (18:16 -0400)]
convert-hf : save memory with lazy evaluation (#7075)
* convert-hf : begin refactoring write_tensor
* convert : upgrade to sentencepiece v0.2.0
* convert-hf : remove unused n_dims in extra_*_tensors
* convert-hf : simplify MoE weights stacking
* convert-hf : flake8 linter doesn't like semicolons
* convert-hf : allow unusual model part names
For example, loading `model-00001-of-00001.safetensors` now works.
* convert-hf : fix stacking MoE expert tensors
`torch.stack` and `torch.cat` don't do the same thing.
* convert-hf : fix Mamba conversion
Tested to work even with a SentencePiece-based tokenizer.
* convert : use a string for the SentencePiece tokenizer path
* convert-hf : display tensor shape
* convert-hf : convert norms to f32 by default
* convert-hf : sort model part names
`os.listdir` is said to list files in arbitrary order.
Sorting the file names should let "model-00009-of-00042.safetensors"
be loaded before "model-00010-of-00042.safetensors".
* convert-hf : use an ABC for Model again
It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (why did it seem to work?)
At least there's still a way to throw an error when forgetting to define
the `model_arch` property of any registered Model subclasses.
* convert-hf : use a plain class for Model, and forbid direct instantiation
There are no abstract methods used anyway,
so using ABC isn't really necessary.
* convert-hf : more consistent formatting of cmdline args
* convert-hf : align the message logged for converted tensors
* convert-hf : fix Refact conversion
* convert-hf : save memory with lazy evaluation
* convert-hf : flake8 doesn't like lowercase L as a variable name
* convert-hf : remove einops requirement for InternLM2
* convert-hf : faster model parts loading
Instead of pre-loading them all into a dict, iterate on the tensors
in the model parts progressively as needed in Model.write_tensors
Conversion for some architectures relies on checking for the presence
of specific tensor names, so for multi-part models, the weight map is read
from the relevant json file to quickly get these names up-front.
* convert-hf : minor changes for consistency
* gguf-py : add tqdm as a dependency
It's small, and used for a progress bar
in GGUFWriter.write_tensors_to_file
* Fall back if graph capture fails and address other comments
* - renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS
- rename env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS
- updated Makefile build to enable CUDA graphs
- removed graph capture failure checking in ggml_cuda_error
using a global variable to track this is not thread safe, but I am also not safistied with checking an error by string
if this is necessary to workaround some issues with graph capture with eg. cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context
- fixed several resource leaks
- fixed issue with zero node graphs
- changed fixed size arrays to vectors
- removed the count of number of evaluations before start capturing, and instead changed the capture mode to relaxed
- removed the check for multiple devices so that it is still possible to use a single device, instead checks for split buffers to disable cuda graphs with -sm row
- changed the op for checking batch size to GGML_OP_ADD, should be more reliable than GGML_OP_SOFT_MAX
- code style fixes
- things to look into
- VRAM usage of the cudaGraphExec_t, if it is significant we may need to make it optional
- possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which cuda graph nodes
This encoding has the same number of exponent bits as float32. That
makes conversion relatively straightforward, even in the absence of
hardware support. For example, converting brain16 to binary32 means
simply shifting 16 bits to the left.
The issue is that converting bf16 to fp16 can result in information
loss. Only 13% of bf16 numbers can be precisely represented in fp16
which in practice ends up being 99.71% of Mistral 7b v0.2's weights
however there is currently no way other than fp32 to get the others
This change fixes that, by adding a bf16 data type to GGML. Support
for CPU inference has been implemented along with optimizations for
the AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2
improves somewhere around -0.0024 to -0.0046 compared to using fp16
* Remove GGML code that's not needed
* Minimize the GGML API surface area for BF16
* Remove bf16 luts
* Make the GGML header look nicer
* Fix documentation
* Apply ggerganov's fixes for test-backend-ops
* Add BF16 code for new ggml_validate_row_data() function
jukofyork [Wed, 8 May 2024 00:24:16 +0000 (01:24 +0100)]
Fixed save_imatrix to match old behaviour for MoE (#7099)
* Fixed save_imatrix to match old behaviour for MoE
This fix is simple and clear, but unnecessarily doubles the memory overhead..
* Fixed missing idx variable
* Unconditionally increment ncall
Co-authored-by: slaren <redacted>
* Fixed 2 bugs in save_imatrix()
- Fixed segfault bug because the counts vector needed to be created.
- Fixed pre-existing bug didn't actually add to the counts for "--combine" option.
William Tambellini [Mon, 6 May 2024 18:12:14 +0000 (11:12 -0700)]
Add an option to build without CUDA VMM (#7067)
Add an option to build ggml cuda without CUDA VMM
resolves
https://github.com/ggerganov/llama.cpp/issues/6889
https://forums.developer.nvidia.com/t/potential-nvshmem-allocated-memory-performance-issue/275416/4
maor-ps [Sat, 4 May 2024 09:06:40 +0000 (12:06 +0300)]
If first token generated from the server is the stop word the server will crash (#7038)
This will reproduce the issue in llama13b
{
'prompt': 'Q: hello world \nA: ',
'stop': ['\n'],
'temperature': 0.0,
'n_predict': 10,
'cache_prompt': True,
'n_probs': 10
}
Daniel Bevenius [Fri, 3 May 2024 13:24:30 +0000 (15:24 +0200)]
llama : rename ctx to user_data in progress_callback (#7045)
* llama : rename ctx to user_data in progress_callback
This commit renames the `ctx` parameter to `user_data` in the
`llama_progress_callback` typedef.
The motivation for this is that other callbacks use `user_data` or
`data`, and using `ctx` in this case might be confusing as it could be
confused with `llama_context`.
Georgi Gerganov [Tue, 30 Apr 2024 09:16:08 +0000 (12:16 +0300)]
ggml : add Flash Attention (#5021)
* ggml : add ggml_flash_attn_ext API
* ggml : fix GQA support in ggml_flash_attn_ext
* ggml : online attention (CPU)
* metal : initial implementation
* metal : f16 precision
* metal : reduce branches
* metal : specialize for head size
* wip : 8 rows per simd group
* wip : 4 rows per simd group
* wip : template for rows per warp
* metal : parallelize across KV size
* metal : parallel reduce across heads
* metal : efficient flash_attn_f16 implementation
* metal : avoid redundant loads of the attention
* metal : scale and mask in matrix form
* metal : fix comment
* llama : avoid ggml_cast, use F32 query
* metal : add parallel reduce version (disabled)
* metal : move output into local memory + optimize
- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments
* metal : add tests, fix scaling, support C > 32
* metal : improve precision
* ggml : fix f16 mad
* metal : minor
* metal : support Q > 8
* tests : add ATTN tests
* metal : disable buffer allocation logs
* tests : more
* metal : faster inner loop for C == 32
* metal : fix array initialization
* tests : ifdef
* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
* ggml : fix ggml_soft_max mask requirement
* cuda : fix soft_max to use correct mask size
* cuda : add flash_attn kernel (wip)
* metal : optimize softmax for C > 32
* metal : optimize softmax
* tests : minor fix
* cuda : avoid zeroing fragments
* tests : update dims
* cuda : fix __hisinf() result check
* cuda : avoid warp_reduce for smax
* cuda : use int instead of int64_t
Noticeably improves performance (thanks to Johannes)
* cuda : make loops use the same loop values
Thanks Johannes again for the tip
* cuda : unroll some of the loops
* cuda : avoid __hisinf branches
* cuda : use half2 in softmax
* cuda : switch to 1 warp for bs > 16
* cuda : speed-up reduce part of the kernel
* cuda : unroll Q*K^T loop
* cuda : fix -INF block check
* cuda : simplify softmax
* cuda : fix matrix names
* cuda : minor
* llama : adapt to F16 KQ_pos
* llama : adapt new models to F16 KQ_mask
* ggml : fix F16 store (ARM NEON)
* llama : fix type of KQ_mask and KQ_pos
* ggml : fix CPU soft_max
* tests : add hs=256
* cuda : fix build
* metal : improve perf via smaller int registers
* cuda : adapt soft_max to F16 mask and pos
* CUDA: faster FlashAttention, kernel for bs == 1
* 16 cols for Phi-2
* no vec for hs, no hs==256 ncols==32 for Volta
* adjust kernel selection logic
* 4 warps, 256 stride for all D
* no ncols == 64
* Multiple parallel blocks for batch size 1
* fix compile warnings
* fix excessive KQ_b loads
* fix cmake build
* fix KV cache padding, NaN from INFINITY (#6438)
* llama : flash_attn cparam + fix defrag
* server: support flash_attn param
* server: bench: enable flash_attn param
* CUDA: refactor host code, dyn. par. blocks
* fix flash_attn_vec_f16 race condition
* flush softmax exp below threshold to 0
* store temp KQ in registers
* Calculate KQ as FP32 if KQV has GGML_PREC_F32
* Add __hgt2_mask implementation for CUDA 11
* fix KQ FP32 precision fpr parallel_blocks > 1
* llama-bench : add -fa,--flash-attn arg
* metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
* metal : use F32 attention accumulators
* batched-bench : add fattn arg
* llama : simplify llama_build_kv_store
ggml-ci
* llama : adapt build_olmo to changes
* ggml : fix arm fp16 store on windows
* metal : clean-up
* metal : clean-up kernel code
* metal : minor
* tests : remove benchmarks
ggml-ci
* ggml : fix avx512 const correctness
ggml-ci
* ggml : fix soft_max with bias on CPU
ggml-ci
* common : print --flash-attn in help
* ggml : fix num dimensions in ggml_flash_attn_ext
* llama : force disable flash attention for incompatible models
* ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
* cuda : uint -> uint32_t
* cuda : "constexpr dim3" -> "const dim3"
ggml-ci
* cuda : try to fix __hgt2_mask
ggml-ci
* ggml : add TODO's for F16/F32 mask/pos support in other backends
* llama : replace bool need_kq_pos with use_alibi
* llama : prep ALiBi support for BERT models
ggml-ci
* llama : fix n_batch requirements
ggml-ci
* cont
* server : add help for --flash-attn arg
* llama : disable FA for AMD
* tests : remove TMP_ATTN_BENCH
ggml-ci
* llama : support save/load state with FA enabled
ggml-ci
* ci : add CUDA save-load-state tests
ggml-ci
* llama : llama_kv_cache_clear zeroes data + fix save-load seq
ggml-ci
* llama : fix copy-paste errors, add TODO
* llama : disallow incompatible states
* llama : update llama_state_get_size after v_trans field
* metal : remove tmp log
* llama : add static reminder for llama_state_get_size
* metal : fix max nsg
ggml-ci
* ci : fix arg order
ggml-ci
---------
Co-authored-by: Johannes Gäßler <redacted> Co-authored-by: Pierrick HYMBERT <redacted>
* Cleaning up integration tests to share code between tests and make it simpler to add new tests.
* Add tests around quantifiers to ensure both matching and non-matching compliance.
* Add slightly more complex grammar with quantifiers to test references with quantifiers.
* Fixing build when C++17 is not present.
* Separating test calls to give more helpful stack traces on failure. Adding verbose messages to give visibility for what is being tested.
* Adding quotes around strings to explicitly show whitespace
* Removing trailing whitespace.
* Implementing suggestions from @ochafik -- grammars and test strings now print and flush before tests to aid in debugging segfaults and whatnot.
* Cleaning up forgotten symbols. Modifying simple test to use test harness. Added comments for more verbose descriptions of what each test is accomplishing.
* Unicode symbol modifications to hopefully make log easier to parse visually.
Daniel Bevenius [Thu, 25 Apr 2024 12:38:14 +0000 (14:38 +0200)]
clip : rename lerp function to avoid conflict (#6894)
This commit renamesthe lerp (linear interpolation) function in clip.cpp
to avoid a conflict with the lerp function in the <cmath> standard C++
library when using c++20.
The motivation for this change is to enable projects that use c++20 to
be able to compile clip.cpp without having to resort to patching it. The
lerp function was added to cmath in version C++20 (202002L) and is why
this is not causing any issue at the moment as C++11/C++17 is currently
used by llama.cpp.
I realize that llama.cpp uses either C++11 (or C++17 in the case for
SYCL) but wanted to ask if this would be an acceptable change just the
same.