]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama.swiftui : update models layout (#4826)
authorZay <redacted>
Fri, 12 Jan 2024 12:48:00 +0000 (05:48 -0700)
committerGitHub <redacted>
Fri, 12 Jan 2024 12:48:00 +0000 (14:48 +0200)
* Updated Models Layout

- Added a models drawer
- Added downloading directly from Hugging Face
- Load custom models from local folder
- Delete models by swiping left

* trimmed trailing white space

* Updated Models Layout

examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj
examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
examples/llama.swiftui/llama.swiftui/UI/ContentView.swift
examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift
examples/llama.swiftui/llama.swiftui/UI/InputButton.swift [new file with mode: 0644]

index a8848a49fce6dca58ef88885d3272dad0693c1f1..3950b9e9df843118dbde9944cd394693b08998e3 100644 (file)
@@ -8,6 +8,7 @@
 
 /* Begin PBXBuildFile section */
                549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; };
+               79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */; };
                7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */; };
                8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */; };
                8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83782AC328BD0096AF73 /* ContentView.swift */; };
@@ -22,6 +23,7 @@
 
 /* Begin PBXFileReference section */
                549479CA2AC9E16000E0F78B /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
+               79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InputButton.swift; sourceTree = "<group>"; };
                7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DownloadButton.swift; sourceTree = "<group>"; };
                8A1C83732AC328BD0096AF73 /* llama.swiftui.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = llama.swiftui.app; sourceTree = BUILT_PRODUCTS_DIR; };
                8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = llama_swiftuiApp.swift; sourceTree = "<group>"; };
                                7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */,
                                8A1C83782AC328BD0096AF73 /* ContentView.swift */,
                                F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */,
+                               79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */,
                        );
                        path = UI;
                        sourceTree = "<group>";
                                8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */,
                                8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */,
                                7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */,
+                               79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */,
                        );
                        runOnlyForDeploymentPostprocessing = 0;
                };
                                CLANG_ENABLE_MODULES = YES;
                                CODE_SIGN_STYLE = Automatic;
                                CURRENT_PROJECT_VERSION = 1;
-                               DEVELOPMENT_TEAM = STLSG3FG8Q;
+                               DEVELOPMENT_TEAM = K5UQJPP73A;
                                ENABLE_PREVIEWS = YES;
                                GENERATE_INFOPLIST_FILE = YES;
                                INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
                                CLANG_ENABLE_MODULES = YES;
                                CODE_SIGN_STYLE = Automatic;
                                CURRENT_PROJECT_VERSION = 1;
-                               DEVELOPMENT_TEAM = STLSG3FG8Q;
+                               DEVELOPMENT_TEAM = K5UQJPP73A;
                                ENABLE_PREVIEWS = YES;
                                GENERATE_INFOPLIST_FILE = YES;
                                INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
index 17cb5b9dde942bac7443baa01dbda6c36d88a66c..5bde1891727ce5d4cdf0c0eaa4561568149bc339 100644 (file)
@@ -1,9 +1,19 @@
 import Foundation
 
+struct Model: Identifiable {
+    var id = UUID()
+    var name: String
+    var url: String
+    var filename: String
+    var status: String?
+}
+
 @MainActor
 class LlamaState: ObservableObject {
     @Published var messageLog = ""
     @Published var cacheCleared = false
+    @Published var downloadedModels: [Model] = []
+    @Published var undownloadedModels: [Model] = []
     let NS_PER_S = 1_000_000_000.0
 
     private var llamaContext: LlamaContext?
@@ -13,23 +23,102 @@ class LlamaState: ObservableObject {
     }
 
     init() {
+        loadModelsFromDisk()
+        loadDefaultModels()
+    }
+
+    private func loadModelsFromDisk() {
+        do {
+            let documentsURL = getDocumentsDirectory()
+            let modelURLs = try FileManager.default.contentsOfDirectory(at: documentsURL, includingPropertiesForKeys: nil, options: [.skipsHiddenFiles, .skipsSubdirectoryDescendants])
+            for modelURL in modelURLs {
+                let modelName = modelURL.deletingPathExtension().lastPathComponent
+                downloadedModels.append(Model(name: modelName, url: "", filename: modelURL.lastPathComponent, status: "downloaded"))
+            }
+        } catch {
+            print("Error loading models from disk: \(error)")
+        }
+    }
+
+    private func loadDefaultModels() {
         do {
             try loadModel(modelUrl: defaultModelUrl)
         } catch {
             messageLog += "Error!\n"
         }
+
+        for model in defaultModels {
+            let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename)
+            if FileManager.default.fileExists(atPath: fileURL.path) {
+
+            } else {
+                var undownloadedModel = model
+                undownloadedModel.status = "download"
+                undownloadedModels.append(undownloadedModel)
+            }
+        }
     }
 
+    func getDocumentsDirectory() -> URL {
+        let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
+        return paths[0]
+    }
+    private let defaultModels: [Model] = [
+        Model(name: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf", status: "download"),
+        Model(
+            name: "TinyLlama-1.1B Chat (Q8_0, 1.1 GiB)",
+            url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true",
+            filename: "tinyllama-1.1b-chat-v1.0.Q8_0.gguf", status: "download"
+        ),
+
+        Model(
+            name: "TinyLlama-1.1B (F16, 2.2 GiB)",
+            url: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
+            filename: "tinyllama-1.1b-f16.gguf", status: "download"
+        ),
+
+        Model(
+            name: "Phi-2.7B (Q4_0, 1.6 GiB)",
+            url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
+            filename: "phi-2-q4_0.gguf", status: "download"
+        ),
+
+        Model(
+            name: "Phi-2.7B (Q8_0, 2.8 GiB)",
+            url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
+            filename: "phi-2-q8_0.gguf", status: "download"
+        ),
+
+        Model(
+            name: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)",
+            url: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
+            filename: "mistral-7b-v0.1.Q4_0.gguf", status: "download"
+        ),
+        Model(
+            name: "OpenHermes-2.5-Mistral-7B (Q3_K_M, 3.52 GiB)",
+            url: "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q3_K_M.gguf?download=true",
+            filename: "openhermes-2.5-mistral-7b.Q3_K_M.gguf", status: "download"
+        )
+    ]
     func loadModel(modelUrl: URL?) throws {
         if let modelUrl {
             messageLog += "Loading model...\n"
             llamaContext = try LlamaContext.create_context(path: modelUrl.path())
             messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
+
+            // Assuming that the model is successfully loaded, update the downloaded models
+            updateDownloadedModels(modelName: modelUrl.lastPathComponent, status: "downloaded")
         } else {
             messageLog += "Load a model from the list below\n"
         }
     }
 
+
+    private func updateDownloadedModels(modelName: String, status: String) {
+        undownloadedModels.removeAll { $0.name == modelName }
+    }
+
+
     func complete(text: String) async {
         guard let llamaContext else {
             return
index 7c81ea256ffd7eec6d9f7616d66d1ad5d197eba5..30c2dc4310210adf99ec7a773ad0a5e82cda1641 100644 (file)
@@ -2,115 +2,57 @@ import SwiftUI
 
 struct ContentView: View {
     @StateObject var llamaState = LlamaState()
-
     @State private var multiLineText = ""
-
-    private static func cleanupModelCaches() {
-        // Delete all models (*.gguf)
-        let fileManager = FileManager.default
-        let documentsUrl =  FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
-        do {
-            let fileURLs = try fileManager.contentsOfDirectory(at: documentsUrl, includingPropertiesForKeys: nil)
-            for fileURL in fileURLs {
-                if fileURL.pathExtension == "gguf" {
-                    try fileManager.removeItem(at: fileURL)
-                }
-            }
-        } catch {
-            print("Error while enumerating files \(documentsUrl.path): \(error.localizedDescription)")
-        }
-    }
+    @State private var showingHelp = false    // To track if Help Sheet should be shown
 
     var body: some View {
-        VStack {
-            ScrollView(.vertical, showsIndicators: true) {
-                Text(llamaState.messageLog)
-                .font(.system(size: 12))
-                .frame(maxWidth: .infinity, alignment: .leading)
-                .padding()
-                .onTapGesture {
-                    UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
+        NavigationView {
+            VStack {
+                ScrollView(.vertical, showsIndicators: true) {
+                    Text(llamaState.messageLog)
+                        .font(.system(size: 12))
+                        .frame(maxWidth: .infinity, alignment: .leading)
+                        .padding()
+                        .onTapGesture {
+                            UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
+                        }
                 }
-            }
 
-            TextEditor(text: $multiLineText)
-                .frame(height: 80)
-                .padding()
-                .border(Color.gray, width: 0.5)
+                TextEditor(text: $multiLineText)
+                    .frame(height: 80)
+                    .padding()
+                    .border(Color.gray, width: 0.5)
 
-            HStack {
-                Button("Send") {
-                    sendText()
-                }
+                HStack {
+                    Button("Send") {
+                        sendText()
+                    }
 
-                Button("Bench") {
-                    bench()
-                }
+                    Button("Bench") {
+                        bench()
+                    }
 
-                Button("Clear") {
-                    clear()
-                }
+                    Button("Clear") {
+                        clear()
+                    }
 
-                Button("Copy") {
-                    UIPasteboard.general.string = llamaState.messageLog
+                    Button("Copy") {
+                        UIPasteboard.general.string = llamaState.messageLog
+                    }
                 }
-            }.buttonStyle(.bordered)
-
-            VStack(alignment: .leading) {
-                DownloadButton(
-                    llamaState: llamaState,
-                    modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
-                    modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
-                    filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
-                )
-
-                DownloadButton(
-                    llamaState: llamaState,
-                    modelName: "TinyLlama-1.1B (Q8_0, 1.1 GiB)",
-                    modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
-                    filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
-                )
-
-                DownloadButton(
-                    llamaState: llamaState,
-                    modelName: "TinyLlama-1.1B (F16, 2.2 GiB)",
-                    modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
-                    filename: "tinyllama-1.1b-f16.gguf"
-                )
-
-                DownloadButton(
-                    llamaState: llamaState,
-                    modelName: "Phi-2.7B (Q4_0, 1.6 GiB)",
-                    modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
-                    filename: "phi-2-q4_0.gguf"
-                )
-
-                DownloadButton(
-                    llamaState: llamaState,
-                    modelName: "Phi-2.7B (Q8_0, 2.8 GiB)",
-                    modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
-                    filename: "phi-2-q8_0.gguf"
-                )
-
-                DownloadButton(
-                    llamaState: llamaState,
-                    modelName: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)",
-                    modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
-                    filename: "mistral-7b-v0.1.Q4_0.gguf"
-                )
-
-                Button("Clear downloaded models") {
-                    ContentView.cleanupModelCaches()
-                    llamaState.cacheCleared = true
+                .buttonStyle(.bordered)
+                .padding()
+
+                NavigationLink(destination: DrawerView(llamaState: llamaState)) {
+                    Text("View Models")
                 }
+                .padding()
 
-                LoadCustomButton(llamaState: llamaState)
             }
-            .padding(.top, 4)
-            .font(.system(size: 12))
-            .frame(maxWidth: .infinity, alignment: .leading)
+            .padding()
+            .navigationBarTitle("Model Settings", displayMode: .inline)
+
         }
-        .padding()
     }
 
     func sendText() {
@@ -131,8 +73,73 @@ struct ContentView: View {
             await llamaState.clear()
         }
     }
+    struct DrawerView: View {
+
+        @ObservedObject var llamaState: LlamaState
+        @State private var showingHelp = false
+        func delete(at offsets: IndexSet) {
+            offsets.forEach { offset in
+                let model = llamaState.downloadedModels[offset]
+                let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename)
+                do {
+                    try FileManager.default.removeItem(at: fileURL)
+                } catch {
+                    print("Error deleting file: \(error)")
+                }
+            }
+
+            // Remove models from downloadedModels array
+            llamaState.downloadedModels.remove(atOffsets: offsets)
+        }
+
+        func getDocumentsDirectory() -> URL {
+            let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
+            return paths[0]
+        }
+        var body: some View {
+            List {
+                Section(header: Text("Download Models From Hugging Face")) {
+                    HStack {
+                        InputButton(llamaState: llamaState)
+                    }
+                }
+                Section(header: Text("Downloaded Models")) {
+                    ForEach(llamaState.downloadedModels) { model in
+                        DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename)
+                    }
+                    .onDelete(perform: delete)
+                }
+                Section(header: Text("Default Models")) {
+                    ForEach(llamaState.undownloadedModels) { model in
+                        DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename)
+                    }
+                }
+
+            }
+            .listStyle(GroupedListStyle())
+            .navigationBarTitle("Model Settings", displayMode: .inline).toolbar {
+                ToolbarItem(placement: .navigationBarTrailing) {
+                    Button("Help") {
+                        showingHelp = true
+                    }
+                }
+            }.sheet(isPresented: $showingHelp) {    // Sheet for help modal
+                VStack(alignment: .leading) {
+                    VStack(alignment: .leading) {
+                        Text("1. Make sure the model is in GGUF Format")
+                               .padding()
+                        Text("2. Copy the download link of the quantized model")
+                               .padding()
+                    }
+                    Spacer()
+                   }
+            }
+        }
+    }
 }
 
-//#Preview {
-//    ContentView()
-//}
+struct ContentView_Previews: PreviewProvider {
+    static var previews: some View {
+        ContentView()
+    }
+}
index c9f322ca14e72581a7e2abb0604f70523db4bd74..4584d6eaa3d3265ef5917a1566ff2ea08a5b7048 100644 (file)
@@ -53,6 +53,8 @@ struct DownloadButton: View {
 
                     llamaState.cacheCleared = false
 
+                    let model = Model(name: modelName, url: modelUrl, filename: filename, status: "downloaded")
+                    llamaState.downloadedModels.append(model)
                     status = "downloaded"
                 }
             } catch let err {
diff --git a/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift b/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift
new file mode 100644 (file)
index 0000000..c5ffbad
--- /dev/null
@@ -0,0 +1,131 @@
+import SwiftUI
+
+struct InputButton: View {
+    @ObservedObject var llamaState: LlamaState
+    @State private var inputLink: String = ""
+    @State private var status: String = "download"
+    @State private var filename: String = ""
+
+    @State private var downloadTask: URLSessionDownloadTask?
+    @State private var progress = 0.0
+    @State private var observation: NSKeyValueObservation?
+
+    private static func extractModelInfo(from link: String) -> (modelName: String, filename: String)? {
+        guard let url = URL(string: link),
+              let lastPathComponent = url.lastPathComponent.components(separatedBy: ".").first,
+              let modelName = lastPathComponent.components(separatedBy: "-").dropLast().joined(separator: "-").removingPercentEncoding,
+              let filename = lastPathComponent.removingPercentEncoding else {
+            return nil
+        }
+
+        return (modelName, filename)
+    }
+
+    private static func getFileURL(filename: String) -> URL {
+        FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename)
+    }
+
+    private func download() {
+        guard let extractedInfo = InputButton.extractModelInfo(from: inputLink) else {
+            // Handle invalid link or extraction failure
+            return
+        }
+
+        let (modelName, filename) = extractedInfo
+        self.filename = filename  // Set the state variable
+
+        status = "downloading"
+        print("Downloading model \(modelName) from \(inputLink)")
+        guard let url = URL(string: inputLink) else { return }
+        let fileURL = InputButton.getFileURL(filename: filename)
+
+        downloadTask = URLSession.shared.downloadTask(with: url) { temporaryURL, response, error in
+            if let error = error {
+                print("Error: \(error.localizedDescription)")
+                return
+            }
+
+            guard let response = response as? HTTPURLResponse, (200...299).contains(response.statusCode) else {
+                print("Server error!")
+                return
+            }
+
+            do {
+                if let temporaryURL = temporaryURL {
+                    try FileManager.default.copyItem(at: temporaryURL, to: fileURL)
+                    print("Writing to \(filename) completed")
+
+                    llamaState.cacheCleared = false
+
+                    let model = Model(name: modelName, url: self.inputLink, filename: filename, status: "downloaded")
+                    llamaState.downloadedModels.append(model)
+                    status = "downloaded"
+                }
+            } catch let err {
+                print("Error: \(err.localizedDescription)")
+            }
+        }
+
+        observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in
+            self.progress = progress.fractionCompleted
+        }
+
+        downloadTask?.resume()
+    }
+
+    var body: some View {
+        VStack {
+            HStack {
+                TextField("Paste Quantized Download Link", text: $inputLink)
+                    .textFieldStyle(RoundedBorderTextFieldStyle())
+
+                Button(action: {
+                    downloadTask?.cancel()
+                    status = "download"
+                }) {
+                    Text("Cancel")
+                }
+            }
+
+            if status == "download" {
+                Button(action: download) {
+                    Text("Download Custom Model")
+                }
+            } else if status == "downloading" {
+                Button(action: {
+                    downloadTask?.cancel()
+                    status = "download"
+                }) {
+                    Text("Downloading \(Int(progress * 100))%")
+                }
+            } else if status == "downloaded" {
+                Button(action: {
+                    let fileURL = InputButton.getFileURL(filename: self.filename)
+                    if !FileManager.default.fileExists(atPath: fileURL.path) {
+                        download()
+                        return
+                    }
+                    do {
+                        try llamaState.loadModel(modelUrl: fileURL)
+                    } catch let err {
+                        print("Error: \(err.localizedDescription)")
+                    }
+                }) {
+                    Text("Load Custom Model")
+                }
+            } else {
+                Text("Unknown status")
+            }
+        }
+        .onDisappear() {
+            downloadTask?.cancel()
+        }
+        .onChange(of: llamaState.cacheCleared) { newValue in
+            if newValue {
+                downloadTask?.cancel()
+                let fileURL = InputButton.getFileURL(filename: self.filename)
+                status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download"
+            }
+        }
+    }
+}