#include "llama-vocab.h"
#include "llama-grammar.h"
+#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') {
+ std::string trigger_pattern;
+ llama_grammar * grammar = nullptr;
// TODO: remove trigger_words support.
if (trigger_words != nullptr && num_trigger_words > 0) {
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
- std::string trigger_pattern("[\\s\\S]*?(");
+ trigger_pattern = "[\\s\\S]*?(";
for (size_t i = 0; i < num_trigger_words; ++i) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
if (i > 0) {
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
}
trigger_pattern += ")[\\s\\S]*";
- const auto * trigger_pattern_c = trigger_pattern.c_str();
- trigger_patterns = &trigger_pattern_c;
- num_trigger_patterns = 1;
+
+ std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
+ } else {
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
}
*ctx = {
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
+ /* .grammar = */ grammar,
};
if (!ctx->grammar) {
delete ctx;