jlong context_pointer,
jlong batch_pointer,
jstring jtext,
+ jboolean format_chat,
jint n_len
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
- const auto tokens_list = common_tokenize(context, text, 1);
+ bool parse_special = (format_chat == JNI_TRUE);
+ const auto tokens_list = common_tokenize(context, text, true, parse_special);
auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
}
for (auto id : tokens_list) {
- LOGi("%s", common_token_to_piece(context, id).c_str());
+ LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
}
common_batch_clear(*batch);
context: Long,
batch: Long,
text: String,
+ formatChat: Boolean,
nLen: Int
): Int
}
}
- fun send(message: String): Flow<String> = flow {
+ fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
- val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+ val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) {