]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add support for control vectors (#5970)
authorTheia Vogel <redacted>
Fri, 15 Mar 2024 20:43:02 +0000 (13:43 -0700)
committerGitHub <redacted>
Fri, 15 Mar 2024 20:43:02 +0000 (22:43 +0200)
* control vector api and implementation

* control-vectors : minor code style updates

* disable control vector when data == nullptr

use -1 for disabled range (also on init) in case we ever support controlling layer 0 (embeddings)

---------

Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
llama.cpp
llama.h

index 58fbd05aa35165c0388448bc29595951e3d99522..4912237e0d0f115c077f46b37704c80ad7c46c40 100644 (file)
@@ -568,6 +568,34 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.lora_base = argv[i];
+        } else if (arg == "--control-vector") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vectors.push_back({ 1.0f, argv[i], });
+        } else if (arg == "--control-vector-scaled") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            const char * fname = argv[i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vectors.push_back({ std::stof(argv[i]), fname, });
+        } else if (arg == "--control-vector-layer-range") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vector_layer_start = std::stoi(argv[i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vector_layer_end = std::stoi(argv[i]);
         } else if (arg == "--mmproj") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1095,6 +1123,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
     printf("  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
+    printf("  --control-vector FNAME\n");
+    printf("                        add a control vector\n");
+    printf("  --control-vector-scaled FNAME S\n");
+    printf("                        add a control vector with user defined scaling S\n");
+    printf("  --control-vector-layer-range START END\n");
+    printf("                        layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
@@ -1360,6 +1394,30 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         return std::make_tuple(nullptr, nullptr);
     }
 
+    if (!params.control_vectors.empty()) {
+        if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
+        if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_n_layer(model);
+
+        const auto cvec = llama_control_vector_load(params.control_vectors);
+        if (cvec.n_embd == -1) {
+            llama_free(lctx);
+            llama_free_model(model);
+            return std::make_tuple(nullptr, nullptr);
+        }
+
+        int err = llama_control_vector_apply(lctx,
+                                             cvec.data.data(),
+                                             cvec.data.size(),
+                                             cvec.n_embd,
+                                             params.control_vector_layer_start,
+                                             params.control_vector_layer_end);
+        if (err) {
+            llama_free(lctx);
+            llama_free_model(model);
+            return std::make_tuple(nullptr, nullptr);
+        }
+    }
+
     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
         const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
         float lora_scale = std::get<1>(params.lora_adapter[i]);
@@ -1890,3 +1948,160 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
 
     return sum / (sqrt(sum1) * sqrt(sum2));
 }
+
+//
+// Control vector utils
+//
+
+static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
+    int32_t n_tensors;
+
+    size_t n_bytes = 0;
+
+    uint32_t max_direction_layer = 0;
+
+    llama_control_vector_data result = { -1, {} };
+
+    // calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer
+    {
+        struct ggml_init_params meta_params = {
+            /* .mem_size   = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(),
+            /* .mem_buffer = */ nullptr,
+            /* .no_alloc   = */ true,
+        };
+        ggml_context * meta_ctx = ggml_init(meta_params);
+        struct gguf_init_params meta_gguf_params = {
+            /* .no_alloc = */ true,
+            /* .ctx      = */ &meta_ctx,
+        };
+        struct gguf_context * meta_ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
+        if (!meta_ctx_gguf) {
+            fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
+            ggml_free(meta_ctx);
+            return result;
+        }
+
+        n_tensors = gguf_get_n_tensors(meta_ctx_gguf);
+        for (int i = 0; i < n_tensors; i++) {
+            std::string name = gguf_get_tensor_name(meta_ctx_gguf, i);
+
+            // split on '.'
+            size_t dotpos = name.find('.');
+            if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
+                try {
+                    uint32_t layer = std::stoi(name.substr(dotpos + 1));
+                    if (layer == 0) {
+                        fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
+                        ggml_free(meta_ctx);
+                        gguf_free(meta_ctx_gguf);
+                        return result;
+                    }
+                    if (layer > max_direction_layer) {
+                        max_direction_layer = layer;
+                    }
+                } catch (...) {
+                    fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
+                    ggml_free(meta_ctx);
+                    gguf_free(meta_ctx_gguf);
+                    return result;
+                }
+            }
+
+            struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str());
+            if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) {
+                fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
+                ggml_free(meta_ctx);
+                gguf_free(meta_ctx_gguf);
+                return result;
+            }
+            if (result.n_embd == -1) {
+                result.n_embd = ggml_nelements(tensor_meta);
+            } else if (ggml_nelements(tensor_meta) != result.n_embd) {
+                fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str());
+                ggml_free(meta_ctx);
+                gguf_free(meta_ctx_gguf);
+                return result;
+            }
+            n_bytes += ggml_nbytes(tensor_meta);
+        }
+        ggml_free(meta_ctx);
+        gguf_free(meta_ctx_gguf);
+    }
+
+    if (n_tensors == 0) {
+        fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
+        return result;
+    }
+
+    // load and scale tensors into final control vector context
+    struct ggml_init_params ggml_params = {
+        /* .mem_size   = */ ggml_tensor_overhead() * n_tensors + n_bytes,
+        /* .mem_buffer = */ nullptr,
+        /* .no_alloc   = */ false,
+    };
+    struct ggml_context * ctx = ggml_init(ggml_params);
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ false,
+        /*.ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
+        ggml_free(ctx);
+        return result;
+    }
+
+    // do not store data for layer 0 (it's not used)
+    result.data.resize(result.n_embd * max_direction_layer);
+
+    for (uint32_t il = 1; il <= max_direction_layer; il++) {
+        const std::string name = "direction." + std::to_string(il);
+        const ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
+
+        float * dst = result.data.data() + result.n_embd * (il - 1);
+
+        if (tensor) {
+            const float * src = (const float *) tensor->data;
+            for (int j = 0; j < result.n_embd; j++) {
+                dst[j] = src[j] * load_info.strength;
+            }
+        } else {
+            for (int j = 0; j < result.n_embd; j++) {
+                dst[j] = 0.0f;
+            }
+        }
+    }
+
+    return result;
+}
+
+llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos) {
+    llama_control_vector_data result = { -1, {} };
+
+    for (const auto & info : load_infos) {
+        auto cur = llama_control_vector_load_one(info);
+
+        if (cur.n_embd == -1) {
+            return result;
+        }
+        if (result.n_embd != -1 && (result.n_embd != cur.n_embd || result.data.size() != cur.data.size())) {
+            fprintf(stderr, "%s: control vector in %s does not match previous vector dimensions\n", __func__, info.fname.c_str());
+            return result;
+        }
+
+        if (result.n_embd == -1) {
+            result = std::move(cur);
+        } else {
+            for (size_t i = 0; i < cur.data.size(); i++) {
+                result.data[i] += cur.data[i];
+            }
+        }
+    }
+
+    if (result.n_embd == -1) {
+        fprintf(stderr, "%s: no vectors passed\n", __func__);
+    }
+
+    return result;
+}
index d250eef8b2b6b4e5b2f20cfc91c66798b58bea83..687f3425e8544c804895512e83ee7701fcd266db 100644 (file)
@@ -37,10 +37,13 @@ extern char const *LLAMA_COMMIT;
 extern char const *LLAMA_COMPILER;
 extern char const *LLAMA_BUILD_TARGET;
 
+struct llama_control_vector_load_info;
+
+int32_t get_num_physical_cores();
+
 //
 // CLI argument parsing
 //
-int32_t get_num_physical_cores();
 
 struct gpt_params {
     uint32_t seed                 = LLAMA_DEFAULT_SEED; // RNG seed
@@ -103,6 +106,11 @@ struct gpt_params {
     std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
     std::string lora_base  = "";                              // base model path for the lora adapter
 
+    std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
+
+    int32_t control_vector_layer_start = -1; // layer range for control vector
+    int32_t control_vector_layer_end   = -1; // layer range for control vector
+
     int  ppl_stride        = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int  ppl_output_type   = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                     //                                       (which is more convenient to use for plotting)
@@ -269,3 +277,24 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
 void llama_embd_normalize(const float * inp, float * out, int n);
 
 float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
+
+//
+// Control vector utils
+//
+
+struct llama_control_vector_data {
+    int n_embd;
+
+    // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
+    std::vector<float> data;
+};
+
+struct llama_control_vector_load_info {
+    float strength;
+
+    std::string fname;
+};
+
+// Load control vectors, scale each by strength, and add them together.
+// On error, returns {-1, empty}
+llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
index fc5dd5cb43a7ec4aee2c4cad148cea049f2c3256..52bd718ba89a5769dd6e9581689f1341abc2ed0d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1894,6 +1894,31 @@ struct llama_kv_cache {
     }
 };
 
+struct llama_control_vector {
+    std::vector<struct ggml_tensor *> tensors; // per layer
+    std::vector<struct ggml_context *> ctxs;
+    std::vector<ggml_backend_buffer_t> bufs;
+
+    int32_t layer_start = -1;
+    int32_t layer_end   = -1;
+
+    ggml_tensor * tensor_for(int il) const {
+        if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
+            return nullptr;
+        }
+        return tensors[il];
+    }
+
+    ~llama_control_vector() {
+        for (struct ggml_context * ctx : ctxs) {
+            ggml_free(ctx);
+        }
+        for (ggml_backend_buffer_t buf : bufs) {
+            ggml_backend_buffer_free(buf);
+        }
+    }
+};
+
 struct llama_vocab {
     using id    = int32_t;
     using token = std::string;
@@ -2108,6 +2133,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask;    // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq;     // I32 [kv_size, n_batch]
 
+    // control vectors
+    struct llama_control_vector cvec;
+
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
@@ -5931,6 +5959,12 @@ struct llm_build_context {
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -13366,6 +13400,10 @@ int32_t llama_n_embd(const struct llama_model * model) {
     return model->hparams.n_embd;
 }
 
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
 float llama_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
@@ -13465,6 +13503,96 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
     }
 }
 
+static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
+    GGML_ASSERT(cvec.tensors.empty());
+    GGML_ASSERT(cvec.ctxs.empty());
+    GGML_ASSERT(cvec.bufs.empty());
+
+    // count layer buffer types
+    std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
+    for (int64_t i = 0; i < model.hparams.n_layer; i++) {
+        buft_layer_count[model.buft_layer[i].buft]++;
+    }
+
+    // allocate contexts
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    for (auto & it : buft_layer_count) {
+        int n_layers = it.second;
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ n_layers * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
+            return 1;
+        }
+        ctx_map[it.first] = ctx;
+    }
+
+    // make tensors
+    cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
+    for (size_t il = 1; il < model.hparams.n_layer; il++) {
+        struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
+        ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
+        cvec.tensors.push_back(tensor);
+    }
+
+    // allocate tensors / buffers and zero
+    for (auto it : ctx_map) {
+        ggml_backend_buffer_type_t buft = it.first;
+        ggml_context * ctx = it.second;
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
+            return false;
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        cvec.ctxs.push_back(ctx);
+        cvec.bufs.push_back(buf);
+    }
+
+    return true;
+}
+
+int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
+    const llama_model & model = lctx->model;
+    llama_control_vector & cvec = lctx->cvec;
+
+    if (data == nullptr) {
+        // disable the current control vector (but leave allocated for later)
+        cvec.layer_start = -1;
+        cvec.layer_end   = -1;
+        return 0;
+    }
+
+    if (n_embd != (int) model.hparams.n_embd) {
+        LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
+        return 1;
+    }
+
+    if (cvec.tensors.empty()) {
+        if (!llama_control_vector_init(cvec, model)) {
+            return 1;
+        }
+    }
+
+    cvec.layer_start = il_start;
+    cvec.layer_end   = il_end;
+
+    for (size_t il = 1; il < model.hparams.n_layer; il++) {
+        assert(cvec.tensors[il] != nullptr);
+
+        const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
+        if (off + n_embd <= len) {
+            ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
+        }
+    }
+
+    return 0;
+}
+
 struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
     struct llama_kv_cache_view result = {
         /*.n_cells            = */ 0,
diff --git a/llama.h b/llama.h
index 90aa5372e740b264a76169202228ac58cc0029b8..40dcf54e394f8c22b0300bea3aaac5f437a50474 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -388,6 +388,7 @@ extern "C" {
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
+    LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@@ -435,10 +436,24 @@ extern "C" {
     // Returns 0 on success
     LLAMA_API int32_t llama_model_apply_lora_from_file(
             const struct llama_model * model,
-                      const char * path_lora,
-                           float   scale,
-                      const char * path_base_model,
-                         int32_t   n_threads);
+                          const char * path_lora,
+                               float   scale,
+                          const char * path_base_model,
+                             int32_t   n_threads);
+
+    // Apply a loaded control vector to a llama_context, or if data is NULL, clear
+    // the currently loaded vector.
+    // n_embd should be the size of a single layer's control, and data should point
+    // to an n_embd x n_layers buffer starting from layer 1.
+    // il_start and il_end are the layer range the vector should apply to (both inclusive)
+    // See llama_control_vector_load in common to load a control vector.
+    LLAMA_API int32_t llama_control_vector_apply(
+            struct llama_context * lctx,
+                     const float * data,
+                          size_t   len,
+                         int32_t   n_embd,
+                         int32_t   il_start,
+                         int32_t   il_end);
 
     //
     // KV cache