]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
swift : fix llama-vocab api usage (#11645)
authorJhen-Jie Hong <redacted>
Tue, 4 Feb 2025 11:15:24 +0000 (19:15 +0800)
committerGitHub <redacted>
Tue, 4 Feb 2025 11:15:24 +0000 (13:15 +0200)
* swiftui : fix vocab api usage

* batched.swift : fix vocab api usage

examples/batched.swift/Sources/main.swift
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

index 371917b2ee863187d4d5b3467cbbffa20c449f64..55c31166ca278f1e049f74db071259c1e7c12c85 100644 (file)
@@ -31,6 +31,11 @@ defer {
     llama_model_free(model)
 }
 
+guard let vocab = llama_model_get_vocab(model) else {
+    print("Failed to get vocab")
+    exit(1)
+}
+
 var tokens = tokenize(text: prompt, add_bos: true)
 
 let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
@@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel))
 context_params.n_threads = 8
 context_params.n_threads_batch = 8
 
-let context = llama_new_context_with_model(model, context_params)
+let context = llama_init_from_model(model, context_params)
 guard context != nil else {
     print("Failed to initialize context")
     exit(1)
@@ -141,7 +146,7 @@ while n_cur <= n_len {
         let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
 
         // is it an end of stream? -> mark the stream as finished
-        if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
+        if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
             i_batch[i] = -1
             // print("")
             if n_parallel > 1 {
@@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
     let n_tokens = utf8Count + (add_bos ? 1 : 0)
     let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
-    let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
+    let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
     var swiftTokens: [llama_token] = []
     for i in 0 ..< tokenCount {
         swiftTokens.append(tokens[Int(i)])
@@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
 
 private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
     var result = [CChar](repeating: 0, count: 8)
-    let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
+    let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
     if nTokens < 0 {
         let actualTokensCount = -Int(nTokens)
         result = .init(repeating: 0, count: actualTokensCount)
         let check = llama_token_to_piece(
-            model,
+            vocab,
             token,
             &result,
             Int32(result.count),
index 477c3e6f2e95bbb5a6be3f438326325ddb54ec76..ee7141a66322437a0ea95d2fa6273ab48f781cab 100644 (file)
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
+    private var vocab: OpaquePointer
     private var sampling: UnsafeMutablePointer<llama_sampler>
     private var batch: llama_batch
     private var tokens_list: [llama_token]
@@ -47,6 +48,7 @@ actor LlamaContext {
         self.sampling = llama_sampler_chain_init(sparams)
         llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
         llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
+        vocab = llama_model_get_vocab(model)
     }
 
     deinit {
@@ -79,7 +81,7 @@ actor LlamaContext {
         ctx_params.n_threads       = Int32(n_threads)
         ctx_params.n_threads_batch = Int32(n_threads)
 
-        let context = llama_new_context_with_model(model, ctx_params)
+        let context = llama_init_from_model(model, ctx_params)
         guard let context else {
             print("Could not load context!")
             throw LlamaError.couldNotInitializeContext
@@ -151,7 +153,7 @@ actor LlamaContext {
 
         new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
 
-        if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
+        if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
             print("\n")
             is_done = true
             let new_token_str = String(cString: temporary_invalid_cchars + [0])
@@ -297,7 +299,7 @@ actor LlamaContext {
         let utf8Count = text.utf8.count
         let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
         let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
-        let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
+        let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
 
         var swiftTokens: [llama_token] = []
         for i in 0..<tokenCount {
@@ -316,7 +318,7 @@ actor LlamaContext {
         defer {
             result.deallocate()
         }
-        let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
+        let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false)
 
         if nTokens < 0 {
             let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
@@ -324,7 +326,7 @@ actor LlamaContext {
             defer {
                 newResult.deallocate()
             }
-            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
+            let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false)
             let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
             return Array(bufferPointer)
         } else {