llama : support input embeddings directly (#1910)
author ningshanwutuobang <redacted>
Wed, 28 Jun 2023 15:53:37 +0000 (23:53 +0800)
committer GitHub <redacted>
Wed, 28 Jun 2023 15:53:37 +0000 (18:53 +0300)
* add interface for float input

* fixed inpL shape and type

* add examples of input floats

* add test example for embd input

* fixed sampling

* add free for context

* fixed end condition for generating

* add examples for llava.py

* add README for llava.py

* add README for llava.py

* add example of PandaGPT

* refactor the interface and fixed the styles

* add cmake build for embd-input

* add cmake build for embd-input

* Add MiniGPT-4 example

* change the order of the args of llama_eval_internal

* fix ci error
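
The core of the change is a new `llama_eval_embd` entry point that accepts a float matrix instead of token ids, so externally computed embeddings (for example projected CLIP or ImageBind features) can be spliced into the context. A minimal sketch of how a caller might drive it (not part of the patch; placeholder model path, dummy embedding values, no error handling). The buffer passed to `llama_eval_embd` holds `n_tokens` consecutive vectors of `n_embd` floats:

```
#include "common.h"
#include "llama.h"

#include <tuple>
#include <vector>

int main() {
    gpt_params params;
    params.model = "models/ggml-model.bin"; // placeholder path

    llama_model   * model;
    llama_context * ctx;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);

    int n_past = 0;

    // 1) ordinary token input
    std::vector<llama_token> prompt = ::llama_tokenize(ctx, "user: describe the image\n", true);
    llama_eval(ctx, prompt.data(), (int) prompt.size(), n_past, params.n_threads);
    n_past += (int) prompt.size();

    // 2) direct embedding input: n_tok vectors of n_embd floats each
    const int n_tok  = 10;
    const int n_embd = llama_n_embd(ctx);
    std::vector<float> embd(n_tok * n_embd, 0.0f); // placeholder values
    llama_eval_embd(ctx, embd.data(), n_tok, n_past, params.n_threads);
    n_past += n_tok;

    // logits for the last position are now available via llama_get_logits(ctx)
    llama_free(ctx);
    return 0;
}
```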

16 files changed:
.gitignore
Makefile
convert-lora-to-ggml.py
examples/CMakeLists.txt
examples/embd-input/.gitignore [new file with mode: 0644]
examples/embd-input/CMakeLists.txt [new file with mode: 0644]
examples/embd-input/README.md [new file with mode: 0644]
examples/embd-input/embd-input-lib.cpp [new file with mode: 0644]
examples/embd-input/embd-input-test.cpp [new file with mode: 0644]
examples/embd-input/embd-input.h [new file with mode: 0644]
examples/embd-input/embd_input.py [new file with mode: 0644]
examples/embd-input/llava.py [new file with mode: 0644]
examples/embd-input/minigpt4.py [new file with mode: 0644]
examples/embd-input/panda_gpt.py [new file with mode: 0644]
llama.cpp
llama.h

index e7bfd52e3d63ceb1309f95cc36f261eb52d87cd4..4fccec31b8114e88737f3398b4eb57908bbdd6c8 100644 (file)
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 *.o
 *.a
+*.so
 .DS_Store
 .build/
 .cache/
@@ -39,8 +40,8 @@ models/*
 /vdot
 /server
 /Pipfile
+/embd-input-test
 /libllama.so
-
 build-info.h
 arm_neon.h
 compile_commands.json
index bda11791d5e23a2979a03afe0a8d5e5f8c354067..03f38bdba04ec9dee583e6219d9ca768d53b954d 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembdinput.so embd-input-test
 
 ifdef LLAMA_BUILD_SERVER
        BUILD_TARGETS += server
@@ -272,7 +272,7 @@ libllama.so: llama.o ggml.o $(OBJS)
        $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
 clean:
-       rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch build-info.h
+       rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch embd-input-test build-info.h
 
 #
 # Examples
@@ -305,6 +305,13 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
 server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
        $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
 
+libembdinput.so: examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+       $(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
+
+
+embd-input-test: libembdinput.so examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+       $(CXX) $(CXXFLAGS) $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
+
 train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp    build-info.h ggml.o llama.o $(OBJS)
        $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
index 9090e8d6dd55a76c3816f9985e9c036a195a6210..f43c836f577a6a8be278aaeae3a42b537b05b23e 100644 (file)
--- a/convert-lora-to-ggml.py
+++ b/convert-lora-to-ggml.py
@@ -113,6 +113,10 @@ with open(output_path, "wb") as fout:
 
     write_file_header(fout, params)
     for k, v in model.items():
+        if k.endswith(".default.weight"):
+            k = k.replace(".default.weight", ".weight")
+        if k in ["llama_proj.weight", "llama_proj.bias"]:
+            continue
         if k.endswith("lora_A.weight"):
             if v.dtype != torch.float16 and v.dtype != torch.float32:
                 v = v.float()
@@ -120,7 +124,7 @@ with open(output_path, "wb") as fout:
         else:
             v = v.float()
 
-        t = v.numpy()
+        t = v.detach().numpy()
         tname = translate_tensor_name(k)
         print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
         write_tensor_header(fout, tname, t.shape, t.dtype)
index cf9c4a223133785c74af3b1c3d281c91bcb75336..161960bb853cc01ef0d47f207942c5cd6541e5cc 100644 (file)
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -39,6 +39,7 @@ else()
     add_subdirectory(baby-llama)
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(simple)
+    add_subdirectory(embd-input)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/embd-input/.gitignore b/examples/embd-input/.gitignore
new file mode 100644 (file)
index 0000000..87ef687
--- /dev/null
@@ -0,0 +1,4 @@
+PandaGPT
+MiniGPT-4
+*.pth
+
diff --git a/examples/embd-input/CMakeLists.txt b/examples/embd-input/CMakeLists.txt
new file mode 100644 (file)
index 0000000..2b62395
--- /dev/null
@@ -0,0 +1,15 @@
+set(TARGET embdinput)
+add_library(${TARGET} embd-input-lib.cpp embd-input.h)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
+
+set(TARGET embd-input-test)
+add_executable(${TARGET} embd-input-test.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama embdinput ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/embd-input/README.md b/examples/embd-input/README.md
new file mode 100644 (file)
index 0000000..02d028f
--- /dev/null
@@ -0,0 +1,63 @@
+# Examples of feeding input embeddings directly
+
+## Requirements
+
+Build `libembdinput.so` by running the following command in the main directory (`../../`):
+```
+make
+```
+
+## [LLaVA](https://github.com/haotian-liu/LLaVA/) example (llava.py)
+
+1. Obtain the LLaVA model (follow https://github.com/haotian-liu/LLaVA/ and use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/).
+2. Convert it to ggml format.
+3. Extract `llava_projection.pth` from [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin):
+
+```
+import torch
+
+bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
+pth_path = "./examples/embd-input/llava_projection.pth"
+
+dic = torch.load(bin_path)
+used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
+torch.save({k: dic[k] for k in used_key}, pth_path)
+```
+4. Check the paths of the LLaVA model and `llava_projection.pth` in `llava.py`.
+
+
+## [PandaGPT](https://github.com/yxuansu/PandaGPT) example (panda_gpt.py)
+
+1. Obtain the PandaGPT LoRA model from https://github.com/yxuansu/PandaGPT. Rename the file to `adapter_model.bin` and use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format.
+The `adapter_config.json` should be:
+```
+{
+  "peft_type": "LORA",
+  "fan_in_fan_out": false,
+  "bias": null,
+  "modules_to_save": null,
+  "r": 32,
+  "lora_alpha": 32,
+  "lora_dropout": 0.1,
+  "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
+}
+```
+2. Prepare the `vicuna` v0 model.
+3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model.
+4. Clone the PandaGPT source.
+```
+git clone https://github.com/yxuansu/PandaGPT
+```
+5. Install the requirements of PandaGPT.
+6. Check the paths of the PandaGPT source, the ImageBind model, the LoRA model and the vicuna model in `panda_gpt.py`.
+
+## [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4/) example (minigpt4.py)
+
+1. Obtain MiniGPT-4 model from https://github.com/Vision-CAIR/MiniGPT-4/ and put it in `embd-input`.
+2. Clone the MiniGPT-4 source.
+```
+git clone https://github.com/Vision-CAIR/MiniGPT-4/
+```
+3. Install the requirements of MiniGPT-4.
+4. Prepare the `vicuna` v0 model.
+5. Check the paths of the MiniGPT-4 source, the MiniGPT-4 model and the vicuna model in `minigpt4.py`.
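
All three example scripts above follow the same flow through the `embd-input` interface: evaluate a text prefix, inject the projected image/audio features with `eval_float`, then sample until the end marker. A compressed C++ sketch of that flow (placeholder prompt and feature buffer, no error handling; in the real scripts the features come from CLIP, ImageBind or the Q-Former projection):

```
#include "embd-input.h"

#include <cstdio>
#include <cstring>
#include <vector>

int main(int argc, char ** argv) {
    // argv carries the usual llama.cpp options, e.g. --model <vicuna.bin> -c 2048
    MyModel * mymodel = create_mymodel(argc, argv);

    const int n_tok  = 32;                      // number of injected embedding positions
    const int n_embd = llama_n_embd(mymodel->ctx);
    std::vector<float> feats(n_tok * n_embd);   // placeholder for projected features

    eval_string(mymodel, "user: <Img>");
    eval_float(mymodel, feats.data(), n_tok);   // splice the embeddings into the prompt
    eval_string(mymodel, "</Img> what is in the picture?\nassistant:");

    for (int i = 0; i < 500; i++) {
        const char * tok = sampling(mymodel);
        if (strcmp(tok, "</s>") == 0) break;    // stop at end-of-sequence
        printf("%s", tok);
        fflush(stdout);
    }

    free_mymodel(mymodel);
    return 0;
}
```
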
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
new file mode 100644 (file)
index 0000000..37de52a
--- /dev/null
@@ -0,0 +1,220 @@
+// Defines sigaction on msys:
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "embd-input.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+static llama_context ** g_ctx;
+
+extern "C" {
+
+struct MyModel* create_mymodel(int argc, char ** argv) {
+    gpt_params params;
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return nullptr;
+    }
+
+    fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+    fprintf(stderr, "%s: seed  = %d\n", __func__, params.seed);
+
+    llama_init_backend(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    g_ctx = &ctx;
+
+    // load the model and apply lora adapter, if any
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return nullptr;
+    }
+
+    // print system information
+    {
+        fprintf(stderr, "\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+    }
+    struct MyModel * ret = new MyModel();
+    ret->ctx = ctx;
+    ret->params = params;
+    ret->n_past = 0;
+    // printf("ctx: %d\n", ret->ctx);
+    return ret;
+}
+
+void free_mymodel(struct MyModel * mymodel) {
+    llama_context * ctx = mymodel->ctx;
+    llama_print_timings(ctx);
+    llama_free(ctx);
+    delete mymodel;
+}
+
+
+bool eval_float(void * model, float * input, int N){
+    MyModel * mymodel = (MyModel*)model;
+    llama_context * ctx = mymodel->ctx;
+    gpt_params params = mymodel->params;
+    int n_emb = llama_n_embd(ctx);
+    int n_past = mymodel->n_past;
+    int n_batch = N; // params.n_batch;
+
+    for (int i = 0; i < (int) N; i += n_batch) {
+        int n_eval = (int) N - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_eval;
+    }
+    mymodel->n_past = n_past;
+    return true;
+}
+
+bool eval_tokens(void * model, std::vector<llama_token> tokens) {
+    MyModel * mymodel = (MyModel* )model;
+    llama_context * ctx;
+    ctx = mymodel->ctx;
+    gpt_params params = mymodel->params;
+    int n_past = mymodel->n_past;
+    for (int i = 0; i < (int) tokens.size(); i += params.n_batch) {
+        int n_eval = (int) tokens.size() - i;
+        if (n_eval > params.n_batch) {
+            n_eval = params.n_batch;
+        }
+        if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
+        }
+        n_past += n_eval;
+    }
+    mymodel->n_past = n_past;
+    return true;
+}
+
+bool eval_id(struct MyModel* mymodel, int id) {
+    std::vector<llama_token> tokens;
+    tokens.push_back(id);
+    return eval_tokens(mymodel, tokens);
+}
+
+bool eval_string(struct MyModel * mymodel,const char* str){
+    llama_context * ctx = mymodel->ctx;
+    std::string str2 = str;
+    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
+    eval_tokens(mymodel, embd_inp);
+    return true;
+}
+
+llama_token sampling_id(struct MyModel* mymodel) {
+    llama_context* ctx = mymodel->ctx;
+    gpt_params params = mymodel->params;
+    // int n_ctx = llama_n_ctx(ctx);
+
+    // out of user input, sample next token
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    // const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    // const float   repeat_penalty  = params.repeat_penalty;
+    // const float   alpha_presence  = params.presence_penalty;
+    // const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    // const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+    {
+        auto logits  = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+
+        // Apply params.logit_bias map
+        for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+            logits[it->first] += it->second;
+        }
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+        // TODO: Apply penalties
+        // float nl_logit = logits[llama_token_nl()];
+        // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+        // llama_sample_repetition_penalty(ctx, &candidates_p,
+        //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        //      last_n_repeat, repeat_penalty);
+        // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+        // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        // last_n_repeat, alpha_frequency, alpha_presence);
+        // if (!penalize_nl) {
+        //     logits[llama_token_nl()] = nl_logit;
+        // }
+
+        if (temp <= 0) {
+            // Greedy sampling
+            id = llama_sample_token_greedy(ctx, &candidates_p);
+        } else {
+            if (mirostat == 1) {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                const int mirostat_m = 100;
+                llama_sample_temperature(ctx, &candidates_p, temp);
+                id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+            } else if (mirostat == 2) {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                llama_sample_temperature(ctx, &candidates_p, temp);
+                id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+            } else {
+                // Temperature sampling
+                llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                llama_sample_temperature(ctx, &candidates_p, temp);
+                id = llama_sample_token(ctx, &candidates_p);
+            }
+        }
+    }
+
+    return id;
+}
+
+const char * sampling(struct MyModel * mymodel) {
+    llama_context * ctx = mymodel->ctx;
+    int id = sampling_id(mymodel);
+    // keep the string alive after return, since callers receive a raw pointer
+    static std::string ret;
+    if (id == llama_token_eos()) ret = "</s>";
+    else ret = llama_token_to_str(ctx, id);
+    eval_id(mymodel, id);
+    return ret.c_str();
+}
+
+}
diff --git a/examples/embd-input/embd-input-test.cpp b/examples/embd-input/embd-input-test.cpp
new file mode 100644 (file)
index 0000000..e5e040f
--- /dev/null
@@ -0,0 +1,35 @@
+#include "embd-input.h"
+#include <stdlib.h>
+#include <random>
+#include <string.h>
+
+int main(int argc, char** argv) {
+
+    auto mymodel = create_mymodel(argc, argv);
+    int N = 10;
+    int max_tgt_len = 500;
+    int n_embd = llama_n_embd(mymodel->ctx);
+
+    // add random float embd to test evaluation
+    float * data = new float[N*n_embd];
+    std::default_random_engine e;
+    std::uniform_real_distribution<float>  u(0,1);
+    for (int i=0;i<N*n_embd;i++) {
+        data[i] = u(e);
+    }
+
+    eval_string(mymodel, "user: what is the color of the flag of UN?");
+    eval_float(mymodel, data, N);
+    eval_string(mymodel, "assistant:");
+    eval_string(mymodel, mymodel->params.prompt.c_str());
+    const char* tmp;
+    for (int i=0; i<max_tgt_len; i++) {
+        tmp = sampling(mymodel);
+        if (strcmp(tmp, "</s>")==0) break;
+        printf("%s", tmp);
+        fflush(stdout);
+    }
+    printf("\n");
+    delete[] data; // release the random test embeddings
+    free_mymodel(mymodel);
+    return 0;
+}
diff --git a/examples/embd-input/embd-input.h b/examples/embd-input/embd-input.h
new file mode 100644 (file)
index 0000000..4fefabd
--- /dev/null
@@ -0,0 +1,30 @@
+#ifndef _EMBD_INPUT_H_
+#define _EMBD_INPUT_H_ 1
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+
+extern "C" {
+
+typedef struct MyModel {
+    llama_context* ctx;
+    gpt_params params;
+    int n_past = 0;
+} MyModel;
+
+
+struct MyModel* create_mymodel(int argc, char ** argv);
+
+bool eval_float(void* model, float* input, int N);
+bool eval_tokens(void* model, std::vector<llama_token> tokens);
+bool eval_id(struct MyModel* mymodel, int id);
+bool eval_string(struct MyModel* mymodel, const char* str);
+const char* sampling(struct MyModel* mymodel);
+llama_token sampling_id(struct MyModel* mymodel);
+void free_mymodel(struct MyModel* mymodel);
+
+}
+
+#endif
diff --git a/examples/embd-input/embd_input.py b/examples/embd-input/embd_input.py
new file mode 100644 (file)
index 0000000..be28966
--- /dev/null
@@ -0,0 +1,71 @@
+import ctypes
+from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
+import numpy as np
+import os
+
+libc = cdll.LoadLibrary("./libembdinput.so")
+libc.sampling.restype=c_char_p
+libc.create_mymodel.restype=c_void_p
+libc.eval_string.argtypes=[c_void_p, c_char_p]
+libc.sampling.argtypes=[c_void_p]
+libc.eval_float.argtypes=[c_void_p, POINTER(c_float), c_int]
+
+
+class MyModel:
+    def __init__(self, args):
+        argc = len(args)
+        c_str = [c_char_p(i.encode()) for i in args]
+        args_c = (c_char_p * argc)(*c_str)
+        self.model = c_void_p(libc.create_mymodel(argc, args_c))
+        self.max_tgt_len = 512
+        self.print_string_eval = True
+
+    def __del__(self):
+        libc.free_mymodel(self.model)
+
+    def eval_float(self, x):
+        libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
+
+    def eval_string(self, x):
+        libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
+        if self.print_string_eval:
+            print(x)
+
+    def eval_token(self, x):
+        libc.eval_id(self.model, x)
+
+    def sampling(self):
+        s = libc.sampling(self.model)
+        return s
+
+    def stream_generate(self, end="</s>"):
+        ret = b""
+        end = end.encode()
+        for _ in range(self.max_tgt_len):
+            tmp = self.sampling()
+            ret += tmp
+            yield tmp
+            if ret.endswith(end):
+                break
+
+    def generate_with_print(self, end="</s>"):
+        ret = b""
+        for i in self.stream_generate(end=end):
+            ret += i
+            print(i.decode(errors="replace"), end="", flush=True)
+        print("")
+        return ret.decode(errors="replace")
+
+
+    def generate(self, end="</s>"):
+        text = b"".join(self.stream_generate(end=end))
+        return text.decode(errors="replace")
+
+if __name__ == "__main__":
+    model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
+    model.eval_string("""user: what is the color of the flag of UN?""")
+    x = np.random.random((5120, 10))  # (n_embd, n_tokens); 5120 is the 13B hidden size
+    model.eval_float(x)
+    model.eval_string("""assistant:""")
+    for i in model.stream_generate():
+        print(i.decode(errors="replace"), end="", flush=True)
diff --git a/examples/embd-input/llava.py b/examples/embd-input/llava.py
new file mode 100644 (file)
index 0000000..2f20cb7
--- /dev/null
@@ -0,0 +1,70 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
+from embd_input import MyModel
+import numpy as np
+from torch import nn
+import torch
+from transformers import CLIPVisionModel,  CLIPImageProcessor
+from PIL import Image
+
+# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
+vision_tower = "openai/clip-vit-large-patch14"
+select_hidden_state_layer = -2
+# (vision_config.image_size // vision_config.patch_size) ** 2
+image_token_len = (224//14)**2
+
+class Llava:
+    def __init__(self, args):
+        self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+        self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
+        self.mm_projector = nn.Linear(1024, 5120)
+        self.model = MyModel(["main", *args])
+
+    def load_projection(self, path):
+        state = torch.load(path)
+        self.mm_projector.load_state_dict({
+            "weight": state["model.mm_projector.weight"],
+            "bias": state["model.mm_projector.bias"]})
+
+    def chat(self, question):
+        self.model.eval_string("user: ")
+        self.model.eval_string(question)
+        self.model.eval_string("\nassistant: ")
+        return self.model.generate_with_print()
+
+    def chat_with_image(self, image, question):
+        with torch.no_grad():
+            embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
+            select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
+            image_feature = select_hidden_state[:, 1:]
+            embd_image = self.mm_projector(image_feature)
+            embd_image = embd_image.cpu().numpy()[0]
+        self.model.eval_string("user: ")
+        self.model.eval_token(32003-2) # im_start
+        self.model.eval_float(embd_image.T)
+        for i in range(image_token_len-embd_image.shape[0]):
+            self.model.eval_token(32003-3) # im_patch
+        self.model.eval_token(32003-1) # im_end
+        self.model.eval_string(question)
+        self.model.eval_string("\nassistant: ")
+        return self.model.generate_with_print()
+
+
+if __name__=="__main__":
+    # model from liuhaotian/LLaVA-13b-delta-v1-1
+    a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
+    # Extracted from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
+    # Alternatively, pytorch_model-00003-of-00003.bin can be used directly.
+    a.load_projection(os.path.join(
+        os.path.dirname(__file__),
+        "llava_projection.pth"))
+    response = a.chat_with_image(
+        Image.open("./media/llama1-logo.png").convert('RGB'),
+        "what is the text in the picture?")
+    a.chat("what is the color of it?")
+
+
+
diff --git a/examples/embd-input/minigpt4.py b/examples/embd-input/minigpt4.py
new file mode 100644 (file)
index 0000000..8e98f85
--- /dev/null
@@ -0,0 +1,128 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
+from embd_input import MyModel
+import numpy as np
+from torch import nn
+import torch
+from PIL import Image
+
+minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
+sys.path.insert(0, minigpt4_path)
+from minigpt4.models.blip2 import Blip2Base
+from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
+
+
+class MiniGPT4(Blip2Base):
+    """
+    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
+    """
+    def __init__(self,
+        args,
+        vit_model="eva_clip_g",
+        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+        img_size=224,
+        drop_path_rate=0,
+        use_grad_checkpoint=False,
+        vit_precision="fp32",
+        freeze_vit=True,
+        freeze_qformer=True,
+        num_query_token=32,
+        llama_model="",
+        prompt_path="",
+        prompt_template="",
+        max_txt_len=32,
+        end_sym='\n',
+        low_resource=False,  # use 8 bit and put vit in cpu
+        device_8bit=0
+    ):
+        super().__init__()
+        self.img_size = img_size
+        self.low_resource = low_resource
+        self.preprocessor = Blip2ImageEvalProcessor(img_size)
+
+        print('Loading VIT')
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
+        print('Loading VIT Done')
+        print('Loading Q-Former')
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features
+        )
+        self.Qformer.cls = None
+        self.Qformer.bert.embeddings.word_embeddings = None
+        self.Qformer.bert.embeddings.position_embeddings = None
+        for layer in self.Qformer.bert.encoder.layer:
+            layer.output = None
+            layer.intermediate = None
+        self.load_from_pretrained(url_or_filename=q_former_model)
+        print('Loading Q-Former Done')
+        self.llama_proj = nn.Linear(
+            self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size
+        )
+        self.max_txt_len = max_txt_len
+        self.end_sym = end_sym
+        self.model = MyModel(["main", *args])
+        # system prompt
+        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
+           "You will be able to see the image once I provide it to you. Please answer my questions."
+           "###")
+
+    def encode_img(self, image):
+        image = self.preprocessor(image)
+        image = image.unsqueeze(0)
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+            inputs_llama = self.llama_proj(query_output.last_hidden_state)
+            # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
+        return inputs_llama
+
+    def load_projection(self, path):
+        state = torch.load(path)["model"]
+        self.llama_proj.load_state_dict({
+            "weight": state["llama_proj.weight"],
+            "bias": state["llama_proj.bias"]})
+
+    def chat(self, question):
+        self.model.eval_string("Human: ")
+        self.model.eval_string(question)
+        self.model.eval_string("\n### Assistant:")
+        return self.model.generate_with_print(end="###")
+
+    def chat_with_image(self, image, question):
+        with torch.no_grad():
+            embd_image = self.encode_img(image)
+        embd_image = embd_image.cpu().numpy()[0]
+        self.model.eval_string("Human: <Img>")
+        self.model.eval_float(embd_image.T)
+        self.model.eval_string("</Img> ")
+        self.model.eval_string(question)
+        self.model.eval_string("\n### Assistant:")
+        return self.model.generate_with_print(end="###")
+
+
+if __name__=="__main__":
+    a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
+    a.load_projection(os.path.join(
+        os.path.dirname(__file__),
+        "pretrained_minigpt4.pth"))
+    response = a.chat_with_image(
+        Image.open("./media/llama1-logo.png").convert('RGB'),
+        "what is the text in the picture?")
+    a.chat("what is the color of it?")
diff --git a/examples/embd-input/panda_gpt.py b/examples/embd-input/panda_gpt.py
new file mode 100644 (file)
index 0000000..0cfac5f
--- /dev/null
@@ -0,0 +1,98 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
+from embd_input import MyModel
+import numpy as np
+from torch import nn
+import torch
+
+# use PandaGPT path
+panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
+imagebind_ckpt_path = "./models/panda_gpt/"
+
+sys.path.insert(0, os.path.join(panda_gpt_path,"code","model"))
+from ImageBind.models import imagebind_model
+from ImageBind import data
+
+ModalityType = imagebind_model.ModalityType
+max_tgt_len = 400
+
+class PandaGPT:
+    def __init__(self, args):
+        self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
+        self.visual_encoder.eval()
+        self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120)
+        self.max_tgt_len = max_tgt_len
+        self.model = MyModel(["main", *args])
+        self.generated_text = ""
+        self.device = "cpu"
+
+    def load_projection(self, path):
+        state = torch.load(path, map_location="cpu")
+        self.llama_proj.load_state_dict({
+            "weight": state["llama_proj.weight"],
+            "bias": state["llama_proj.bias"]})
+
+    def eval_inputs(self, inputs):
+        self.model.eval_string("<Img>")
+        embds = self.extract_multimoal_feature(inputs)
+        for i in embds:
+            self.model.eval_float(i.T)
+        self.model.eval_string("</Img> ")
+
+    def chat(self, question):
+        return self.chat_with_image(None, question)
+
+    def chat_with_image(self, inputs, question):
+        if self.generated_text == "":
+            self.model.eval_string("###")
+        self.model.eval_string(" Human: ")
+        if inputs:
+            self.eval_inputs(inputs)
+        self.model.eval_string(question)
+        self.model.eval_string("\n### Assistant:")
+        ret = self.model.generate_with_print(end="###")
+        self.generated_text += ret
+        return ret
+
+    def extract_multimoal_feature(self, inputs):
+        features = []
+        for key in ["image", "audio", "video", "thermal"]:
+            if key + "_paths" in inputs:
+                embeds = self.encode_data(key, inputs[key+"_paths"])
+                features.append(embeds)
+        return features
+
+    def encode_data(self, data_type, data_paths):
+
+        type_map = {
+            "image": ModalityType.VISION,
+            "audio": ModalityType.AUDIO,
+            "video": ModalityType.VISION,
+            "thermal": ModalityType.THERMAL,
+        }
+        load_map = {
+            "image": data.load_and_transform_vision_data,
+            "audio": data.load_and_transform_audio_data,
+            "video": data.load_and_transform_video_data,
+            "thermal": data.load_and_transform_thermal_data
+        }
+
+        load_function = load_map[data_type]
+        key = type_map[data_type]
+
+        inputs = {key: load_function(data_paths, self.device)}
+        with torch.no_grad():
+            embeddings = self.visual_encoder(inputs)
+            embeds = embeddings[key]
+            embeds = self.llama_proj(embeds).cpu().numpy()
+        return embeds
+
+
+if __name__=="__main__":
+    a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
+    a.load_projection("./models/panda_gpt/adapter_model.bin")
+    a.chat_with_image(
+        {"image_paths": ["./media/llama1-logo.png"]},
+        "what is the text in the picture? 'llama' or 'lambda'?")
+    a.chat("what is the color of it?")
index 2482bdd18d2e71da15ef76ff244691ef0ae1d0d7..5a142aba642811d14e46fc1683ccfe31db0f1b99 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1369,22 +1369,26 @@ static bool llama_model_load(
 
 // evaluate the transformer
 //
-//   - lctx:         llama context
-//   - tokens:       new batch of tokens to process
-//   - n_past:       the context size so far
-//   - n_threads:    number of threads to use
-//   - cgraph_fname: filename of the exported computation graph
+//   - lctx:      llama context
+//   - tokens:    new batch of tokens to process
+//   - embd:      embeddings input (used instead of tokens)
+//   - n_tokens:  number of tokens
+//   - n_past:    the context size so far
+//   - n_threads: number of threads to use
 //
 static bool llama_eval_internal(
-        llama_context &  lctx,
-    const llama_token *  tokens,
-            const int    n_tokens,
-            const int    n_past,
-            const int    n_threads,
+         llama_context & lctx,
+     const llama_token * tokens,
+           const float * embd,
+             const int   n_tokens,
+             const int   n_past,
+             const int   n_threads,
             const char * cgraph_fname) {
 
+    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
+
     // enforce that the first token is BOS
-    if (n_past == 0 && tokens[0] != llama_token_bos()) {
+    if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) {
         fprintf(stderr, "%s: first token must be BOS\n", __func__);
         return false;
     }
@@ -1424,12 +1428,18 @@ static bool llama_eval_internal(
     ggml_cgraph gf = {};
     gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    ggml_set_name(embd, "embd");
-    memcpy(embd->data, tokens, N*ggml_element_size(embd));
-
     struct ggml_tensor * cur;
-    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+    struct ggml_tensor * inpL;
+
+    if (tokens) {
+        struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+        ggml_set_name(embd, "embd");
+        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+        inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+    } else {
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+    }
 
     const int i_gpu_start = n_layer - n_gpu_layers;
     (void) i_gpu_start;
@@ -2654,6 +2664,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
+
+
 //
 // interface implementation
 //
@@ -3421,7 +3433,29 @@ int llama_eval(
                          int   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
+    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    // get a more accurate load time, upon first eval
+    // TODO: fix this
+    if (!ctx->has_evaluated_once) {
+        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
+        ctx->has_evaluated_once = true;
+    }
+
+    return 0;
+}
+
+
+int llama_eval_embd(
+            struct llama_context * ctx,
+                     const float * embd,
+                             int   n_tokens,
+                             int   n_past,
+                             int   n_threads) {
+    if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3442,7 +3476,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
 
     const std::vector<llama_token> tmp(n_batch, llama_token_bos());
 
-    if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) {
+    if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
diff --git a/llama.h b/llama.h
index 76239be25fc220f8128b6a38ae302ce3ae7aeefd..c2f2e53312b9e36261c2b65ab0982047bc44f8dc 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -226,6 +226,14 @@ extern "C" {
                              int   n_past,
                              int   n_threads);
 
+    // Same as llama_eval, but takes a float matrix of embeddings as input instead of tokens.
+    LLAMA_API int llama_eval_embd(
+            struct llama_context * ctx,
+                     const float * embd,
+                             int   n_tokens,
+                             int   n_past,
+                             int   n_threads);
+
     // Export a static computation graph for context of 511 and batch size of 1
     // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
     //       parameters here to keep things simple