#include <algorithm>
#include <cstdlib>
#include <cstring>
-#include <iostream>
#include <fstream>
#include <map>
+#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
-struct backend_cli_args {
- const char * model = nullptr;
- const char * test = nullptr;
- const char * device = "cpu";
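+// command-line options: model path, optional single test name, and device ("auto", "cpu" or "gpu")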
+struct test_args {
+ std::string model;
+ std::string test;
+ std::string device = "auto";
};
-struct test_model_context {
- llama_model_ptr model;
- llama_context_ptr ctx;
- int n_vocab = 0;
-
- std::unordered_map<llama_seq_id, int32_t> seq_positions;
- std::unordered_map<llama_seq_id, int32_t> last_batch_info;
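+// state shared by all test cases; the model is loaded once in main() and reused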
+struct test_params {
+ llama_model_ptr model;
+};
- bool load_model(const backend_cli_args & args) {
- if (model) {
- return true;
- }
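+// load the model on the requested device; returns nullptr if the device or model is unavailable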
+static llama_model_ptr load_model(const test_args & args) {
+ auto mparams = llama_model_default_params();
- llama_backend_init();
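+    // explicit device list for mparams.devices; the trailing nullptr terminates the list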
+ ggml_backend_dev_t devs[2] = { nullptr, nullptr };
- auto mparams = llama_model_default_params();
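+    // "auto" keeps the default device selection from llama_model_default_params()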
+ if (args.device != "auto") {
+ if (args.device == "gpu") {
+ devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
- ggml_backend_dev_t devs[2];
- if (std::string_view(args.device) == "gpu") {
- ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
- if (gpu == nullptr) {
+ if (devs[0] == nullptr) {
fprintf(stderr, "Error: GPU requested but not available\n");
- return false;
+ return nullptr;
}
- devs[0] = gpu;
- devs[1] = nullptr; // null terminator
- mparams.devices = devs;
+
mparams.n_gpu_layers = 999;
- } else if (std::string_view(args.device) == "cpu") {
- ggml_backend_dev_t cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
- devs[0] = cpu;
- devs[1] = nullptr; // null terminator
- mparams.devices = devs;
+ } else if (args.device == "cpu") {
+ devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+
+ mparams.n_gpu_layers = 0;
+ } else {
+ fprintf(stderr, "Error: invalid device '%s'\n", args.device.c_str());
+ return nullptr;
}
+ mparams.devices = devs;
+
fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
+ }
- model.reset(llama_model_load_from_file(args.model, mparams));
+ llama_model_ptr res;
- if (!model) {
- fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model);
- return false;
- }
- n_vocab = llama_vocab_n_tokens(get_vocab());
- fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
+ res.reset(llama_model_load_from_file(args.model.c_str(), mparams));
- return true;
+ if (!res) {
+ fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model.c_str());
+ return nullptr;
}
- bool setup(const backend_cli_args & args, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
- if (!model) {
- load_model(args);
- }
+ return res;
+}
- if (ctx) {
- return true;
- }
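+// owns the llama_context for a single test and tracks per-sequence decoding state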
+struct test_context {
+ llama_context_ptr ctx;
+
+ int n_vocab = 0;
+
+ const llama_vocab * vocab = nullptr;
+
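+    // next decode position per sequence, plus bookkeeping for the most recently decoded batch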
+ std::unordered_map<llama_seq_id, int32_t> seq_positions;
+ std::unordered_map<llama_seq_id, int32_t> last_batch_info;
+
+ test_context(const test_params & params, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
+ auto * model = params.model.get();
+
+ GGML_ASSERT(model);
+ GGML_ASSERT(!ctx);
llama_context_params cparams = llama_context_default_params();
cparams.n_ctx = 512;
cparams.n_seq_max = n_seq_max;
}
- ctx.reset(llama_init_from_model(model.get(), cparams));
+ ctx.reset(llama_init_from_model(model, cparams));
if (!ctx) {
- fprintf(stderr, "Warning: failed to create context, skipping test\n");
- return false;
+ throw std::runtime_error("failed to create context");
}
+
llama_set_warmup(ctx.get(), false);
- return true;
+ vocab = llama_model_get_vocab(model);
+ n_vocab = llama_vocab_n_tokens(vocab);
}
bool decode(const std::map<llama_seq_id, std::string> & prompts) {
- if (!ctx) {
- fprintf(stderr, "Error: context not initialized, call setup() first\n");
- return false;
- }
+ GGML_ASSERT(ctx);
last_batch_info.clear();
llama_batch batch = llama_batch_init(512, 0, prompts.size());
- auto vocab = get_vocab();
for (const auto & [seq_id, prompt] : prompts) {
std::vector<llama_token> tokens;
tokens.push_back(llama_vocab_bos(vocab));
}
bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
- if (ctx == nullptr) {
- fprintf(stderr, "Error: context not initialized, call setup() first\n");
- return false;
- }
+ GGML_ASSERT(ctx);
llama_batch batch = llama_batch_init(1, 0, 1);
int32_t pos = seq_positions[seq_id];
seq_positions[seq_id]++;
llama_batch_free(batch);
+
return true;
}
bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
- if (ctx == nullptr) {
- fprintf(stderr, "Error: context not initialized, call setup() first\n");
- return false;
- }
+ GGML_ASSERT(ctx);
llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());
update_batch_info(batch);
llama_batch_free(batch);
+
return true;
}
- std::string token_to_piece(llama_token token, bool special) {
+ std::string token_to_piece(llama_token token, bool special) const {
std::string piece;
        piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
- const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
+ const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
- int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
+ int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
- }
- else {
+ } else {
piece.resize(n_chars);
}
return piece;
}
-
- void reset() {
- ctx.reset();
- seq_positions.clear();
- last_batch_info.clear();
- }
-
- const llama_vocab * get_vocab() const {
- return model ? llama_model_get_vocab(model.get()) : nullptr;
- }
-
};
-static void test_backend_greedy_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_greedy_sampling(const test_params & params) {
const int seq_id = 0;
struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Some"}})) {
GGML_ASSERT(false && "Failed to decode token");
}
}
-static void test_backend_top_k_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_top_k_sampling(const test_params & params) {
const int seq_id = 0;
const int32_t k = 8;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Hello"}})) {
GGML_ASSERT(false && "Failed to decode token");
llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
- const std::string token_str = test_ctx.token_to_piece(token, false);
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
printf("backend top-k hybrid sampling test PASSED\n");
}
-static void test_backend_temp_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
-
+static void test_backend_temp_sampling(const test_params & params) {
{
const float temp_0 = 0.8f;
struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
{ 1, backend_sampler_chain_1.get() }
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
GGML_ASSERT(false && "Failed to decode token");
auto test_argmax_temp = [&](float temp) {
printf("\nTesting temperature = %.1f\n", temp);
- test_ctx.reset();
-
int seq_id = 0;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
{ seq_id, backend_sampler_chain.get() },
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Once"}})) {
GGML_ASSERT(false && "Failed to decode token");
test_argmax_temp(-1.0f);
printf("backend temp sampling test PASSED\n");
-
}
-static void test_backend_temp_ext_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_temp_ext_sampling(const test_params & params) {
{
int seq_id = 0;
const float temp = 0.8f;
{ seq_id, backend_sampler_chain.get() },
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
GGML_ASSERT(false && "Failed to decode token");
}
}
- test_ctx.reset();
-
    // lambda for testing non-positive temp/delta/exponent values.
auto test_argmax_temp = [&](float temp, float delta, float exponent) {
printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent);
- test_ctx.reset();
-
int seq_id = 0;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
{ seq_id, backend_sampler_chain.get() },
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Once"}})) {
GGML_ASSERT(false && "Failed to decode token");
test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling
printf("backend temp_ext sampling test PASSED\n");
-
}
-static void test_backend_min_p_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_min_p_sampling(const test_params & params) {
const int seq_id = 0;
const float p = 0.1;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Hello"}})) {
GGML_ASSERT(false && "Failed to decode token");
printf("min-p sampling test PASSED\n");
}
-static void test_backend_top_p_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_top_p_sampling(const test_params & params) {
const int seq_id = 0;
const float p = 0.9;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Hello"}})) {
return;
printf("top-p sampling test PASSED\n");
}
-static void test_backend_multi_sequence_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_multi_sequence_sampling(const test_params & params) {
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());
{ 1, sampler_chain_1.get() }
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
std::map<llama_seq_id, std::string> prompts = {
{0, "Hello"},
printf("backend multi-sequence sampling test PASSED\n");
}
-static void test_backend_dist_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_dist_sampling(const test_params & params) {
const int seq_id = 189;
const int32_t seed = 88;
+
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Some"}})) {
GGML_ASSERT(false && "Failed to decode token");
printf("backend dist sampling test PASSED\n");
}
-static void test_backend_dist_sampling_and_cpu(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_dist_sampling_and_cpu(const test_params & params) {
const int seq_id = 0;
const int32_t seed = 88;
+
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Some"}})) {
GGML_ASSERT(false && "Failed to decode token");
printf("backend dist & cpu sampling test PASSED\n");
}
-static void test_backend_logit_bias_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
- // Calling load_model to ensure vocab is loaded and can be accessed
- if (!test_ctx.load_model(args)) {
- return;
- }
+static void test_backend_logit_bias_sampling(const test_params & params) {
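+    // this test needs direct vocab access to look up the token that will be biased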
+ const auto * model = params.model.get();
+ const auto * vocab = llama_model_get_vocab(model);
const int seq_id = 0;
- // Create the logit biases vector.
std::vector<llama_logit_bias> logit_bias;
// Get the token for the piece "World".
const std::string piece = "World";
std::vector<llama_token> tokens(16);
- llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
+ llama_tokenize(vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
+
llama_token bias_token = tokens[0];
- logit_bias.push_back({ bias_token, +100.0f });
+ // TODO: biasing too much here makes the Vulkan sampling fail - should be investigated further
+ // https://github.com/ggml-org/llama.cpp/actions/runs/20894267644/job/60030252675?pr=18753#step:3:23350
+ //logit_bias.push_back({ bias_token, +100.0f });
+ logit_bias.push_back({ bias_token, +10.0f });
+
printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
- llama_vocab_n_tokens(test_ctx.get_vocab()),
+ llama_vocab_n_tokens(vocab),
logit_bias.size(),
logit_bias.data()));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
{ seq_id, backend_sampler_chain.get() },
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Hello"}})) {
GGML_ASSERT(false && "Failed to decode token");
}
llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
- const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
- printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
+ printf("sampled token = %d, expected = %d\n", backend_token, bias_token);
GGML_ASSERT(backend_token == bias_token);
printf("backend logit bias sampling test PASSED\n");
// This test verifies that it is possible to have two different backend samplers:
// one that uses the backend dist sampler, and another that uses the CPU dist sampler.
-static void test_backend_mixed_sampling(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_mixed_sampling(const test_params & params) {
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
{ 1, sampler_chain_1.get() }
};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
std::map<llama_seq_id, std::string> prompts = {
{0, "Hello"},
printf("backend mixed sampling test PASSED\n");
}
-static void test_backend_set_sampler(const backend_cli_args & args) {
- test_model_context test_ctx;
-
- const int32_t seed = 88;
+static void test_backend_set_sampler(const test_params & params) {
const int seq_id = 0;
+ const int32_t seed = 88;
+
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
if (!test_ctx.decode({{seq_id, "Hello"}})) {
GGML_ASSERT(false && "Failed to decode token");
printf("backend set sampler test PASSED\n");
}
-static void test_backend_cpu_mixed_batch(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_cpu_mixed_batch(const test_params & params) {
// Sequence 0 uses backend sampling
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
};
// We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
- if (!test_ctx.setup(args, backend_sampler_configs, 2)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs, 2);
std::map<llama_seq_id, std::string> prompts = {
- {0, "Hello"}, // Will use backend sampling
+ {0, "Hello"}, // Will use backend sampling
{1, "Some"} // Will use CPU sampling
};
printf("backend-cpu mixed batch test PASSED\n");
}
-static void test_backend_max_outputs(const backend_cli_args & args) {
- test_model_context test_ctx;
-
+static void test_backend_max_outputs(const test_params & params) {
const int seq_id = 0;
const int32_t seed = 88;
+
llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
- if (!test_ctx.setup(args, backend_sampler_configs)) {
- return;
- }
+ test_context test_ctx(params, backend_sampler_configs);
llama_batch batch = llama_batch_init(512, 0, 1);
std::string prompt = "Hello";
std::vector<llama_token> tokens;
- tokens.push_back(llama_vocab_bos(test_ctx.get_vocab()));
+ tokens.push_back(llama_vocab_bos(test_ctx.vocab));
std::vector<llama_token> prompt_tokens(32);
- int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(),
+ int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
prompt_tokens.data(), prompt_tokens.size(),
false, false);
for (int i = 0; i < n_tokens; i++) {
}
struct backend_test_case {
- const char * name;
- void (*fn)(const backend_cli_args &);
+ std::string name;
+ void (*fn)(const test_params &);
bool enabled_by_default;
};
{ "top_p", test_backend_top_p_sampling, true },
};
-static backend_cli_args parse_backend_cli(int argc, char ** argv) {
- backend_cli_args out;
+static test_args parse_cli(int argc, char ** argv) {
+ test_args out;
for (int i = 1; i < argc; ++i) {
const char * arg = argv[i];
out.device = arg + 9;
continue;
}
- if (!out.model) {
+ if (out.model.empty()) {
out.model = arg;
continue;
}
exit(EXIT_FAILURE);
}
- if (std::strcmp(out.device, "cpu") != 0 && std::strcmp(out.device, "gpu") != 0) {
- fprintf(stderr, "Invalid device '%s'. Must be 'cpu' or 'gpu'\n", out.device);
+ if (out.device != "cpu" && out.device != "gpu" && out.device != "auto") {
+ fprintf(stderr, "Invalid device '%s'. Must be 'cpu', 'gpu' or 'auto'\n", out.device.c_str());
exit(EXIT_FAILURE);
}
return out;
}
-static std::vector<const backend_test_case *> collect_tests_to_run(const char * requested) {
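+// pick the single requested test by name, or fall back to the tests enabled by default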
+static std::vector<const backend_test_case *> collect_tests_to_run(const std::string & requested) {
std::vector<const backend_test_case *> selected;
- if (requested != nullptr) {
+ if (!requested.empty()) {
for (const auto & test : BACKEND_TESTS) {
- if (std::strcmp(test.name, requested) == 0) {
+ if (test.name == requested) {
selected.push_back(&test);
break;
}
}
if (selected.empty()) {
- fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested);
+ fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested.c_str());
for (const auto & test : BACKEND_TESTS) {
- fprintf(stderr, " %s\n", test.name);
+ fprintf(stderr, " %s\n", test.name.c_str());
}
exit(EXIT_FAILURE);
}
return selected;
}
-static void run_tests(const std::vector<const backend_test_case *> & tests, const backend_cli_args & args) {
- for (const auto * test : tests) {
- fprintf(stderr, "\n=== %s ===\n", test->name);
- test->fn(args);
+static void run_tests(const std::vector<const backend_test_case *> & tests, const test_params & args) {
+    for (const auto * test : tests) {
+ fprintf(stderr, "\n=== %s ===\n", test->name.c_str());
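+        // test_context construction throws on failure, so treat any exception as a fatal test error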
+ try {
+ test->fn(args);
+ } catch (const std::exception & e) {
+ fprintf(stderr, "Error running test '%s': %s\n", test->name.c_str(), e.what());
+ exit(EXIT_FAILURE);
+ }
}
}
-
int main(int argc, char ** argv) {
- backend_cli_args args = parse_backend_cli(argc, argv);
+ test_args args = parse_cli(argc, argv);
- if (args.model == nullptr) {
+ if (args.model.empty()) {
args.model = get_model_or_exit(1, argv);
}
- std::ifstream file(args.model);
- if (!file.is_open()) {
- fprintf(stderr, "no model '%s' found\n", args.model);
- return EXIT_FAILURE;
+ {
+ std::ifstream file(args.model);
+ if (!file.is_open()) {
+ fprintf(stderr, "no model '%s' found\n", args.model.c_str());
+ return EXIT_FAILURE;
+ }
}
- fprintf(stderr, "using '%s'\n", args.model);
+ fprintf(stderr, "using '%s'\n", args.model.c_str());
+
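+    // initialize the backends once up front; load_model() no longer does this implicitly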
+ llama_backend_init();
- ggml_time_init();
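+    // load the model a single time and share it across all selected tests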
+ test_params params = {
+ /*.model =*/ load_model(args),
+ };
const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
if (!tests.empty()) {
- run_tests(tests, args);
+ run_tests(tests, params);
}
return 0;