# include "ggml-rpc.h"
#endif
-#ifdef GGML_USE_CUDA
-# include "ggml-cuda.h"
-#elif defined(GGML_USE_VULKAN)
+#if defined(GGML_USE_VULKAN)
# include "ggml-vulkan.h"
#elif defined(GGML_USE_SYCL)
# include "ggml-sycl.h"
LLM_TENSOR_CLS_OUT,
};
-static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
+static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{
LLM_ARCH_LLAMA,
{
return LLM_TENSOR_NAMES.at(arch).at(tensor);
}
- std::string operator()(llm_tensor tensor, const std::string & suffix) const {
+ std::string operator()(llm_tensor tensor, const char * suffix) const {
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
return "__missing__";
}
- return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix;
+ return std::string(LLM_TENSOR_NAMES.at(arch).at(tensor)) + "." + suffix;
}
std::string operator()(llm_tensor tensor, int bid) const {
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
return "__missing__";
}
- return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid);
+ return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid);
}
- std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
+ std::string operator()(llm_tensor tensor, const char * suffix, int bid) const {
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
return "__missing__";
}
- return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix;
+ return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid) + "." + suffix;
}
- std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
+ std::string operator()(llm_tensor tensor, const char * suffix, int bid, int xid) const {
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
return "__missing__";
}
- return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix;
+ return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid) + "." + suffix;
}
};
return piece;
}
-static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
- ggml_backend_buffer_type_t buft = nullptr;
-
-#if defined(GGML_USE_CUDA)
- // host buffers should only be used when data is expected to be copied to/from the GPU
- if (host_buffer) {
- buft = ggml_backend_cuda_host_buffer_type();
- }
-#elif defined(GGML_USE_SYCL)
- if (host_buffer) {
- buft = ggml_backend_sycl_host_buffer_type();
- }
-#elif defined(GGML_USE_CANN)
- if (host_buffer) {
- buft = ggml_backend_cann_host_buffer_type();
- }
-#elif defined(GGML_USE_CPU_HBM)
- buft = ggml_backend_cpu_hbm_buffer_type();
-#elif defined(GGML_USE_VULKAN)
- if (host_buffer) {
- buft = ggml_backend_vk_host_buffer_type();
- }
-#endif
-
- if (buft == nullptr) {
- buft = ggml_backend_cpu_buffer_type();
- }
- return buft;
-
- GGML_UNUSED(host_buffer);
-}
-
//
// globals
//
-struct llama_state {
- llama_state() {
-#ifdef GGML_USE_METAL
- ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data);
-#elif defined(GGML_USE_CUDA)
- ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data);
-#elif defined(GGML_USE_CANN)
- ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data);
-#endif
- }
-
- // We save the log callback globally
+struct llama_logger_state {
ggml_log_callback log_callback = llama_log_callback_default;
void * log_callback_user_data = nullptr;
};
-static llama_state g_state;
+static llama_logger_state g_logger_state;
// available llama models
enum e_model {
std::vector<llama_layer> layers;
+ // gguf metadata
+ std::unordered_map<std::string, std::string> gguf_kv;
+
llama_split_mode split_mode;
int main_gpu;
int n_gpu_layers;
- std::vector<std::string> rpc_servers;
+ // list of devices used in this model
+ std::vector<ggml_backend_dev_t> devices;
- // gguf metadata
- std::unordered_map<std::string, std::string> gguf_kv;
+ std::vector<std::string> rpc_servers;
// layer -> buffer type mapping
struct layer_buft {
ggml_free(ctx);
}
for (ggml_backend_buffer_t buf : bufs) {
-#ifdef GGML_USE_CUDA
- if (ggml_backend_buffer_get_type(buf) == ggml_backend_cpu_buffer_type()) {
- ggml_backend_cuda_unregister_host_buffer(ggml_backend_buffer_get_base(buf));
- }
-#endif
ggml_backend_buffer_free(buf);
}
while (!lora_adapters.empty()) {
}
};
-static size_t llama_get_device_count(const llama_model & model) {
- size_t count = 1;
-#if defined(GGML_USE_CUDA)
- count = ggml_backend_cuda_get_device_count();
+static int llama_get_device_count(const llama_model & model) {
+ int count = (int) model.devices.size();
+
+#if defined(GGML_USE_RPC)
+ count += (int) model.rpc_servers.size();
+#endif
+
+#if defined(GGML_USE_METAL)
+ count += 1;
#elif defined(GGML_USE_SYCL)
- count = ggml_backend_sycl_get_device_count();
+ count += ggml_backend_sycl_get_device_count();
#elif defined(GGML_USE_VULKAN)
- count = ggml_backend_vk_get_device_count();
+ count += ggml_backend_vk_get_device_count();
#elif defined(GGML_USE_CANN)
- return ggml_backend_cann_get_device_count();
-#endif
-#if defined(GGML_USE_RPC)
- count += model.rpc_servers.size();
+ count += ggml_backend_cann_get_device_count();
#endif
+
return count;
+
GGML_UNUSED(model);
}
-static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) {
+static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(const llama_model & model, bool host_buffer) {
ggml_backend_buffer_type_t buft = nullptr;
-#ifdef GGML_USE_RPC
- int rpc_count = (int)model.rpc_servers.size();
-#else
- int rpc_count = 0;
+ if (host_buffer) {
+ for (auto * dev : model.devices) {
+ buft = ggml_backend_dev_host_buffer_type(dev);
+ if (buft != nullptr) {
+ break;
+ }
+ }
+ }
+
+#if defined(GGML_USE_SYCL)
+ if (host_buffer) {
+ buft = ggml_backend_sycl_host_buffer_type();
+ }
+#elif defined(GGML_USE_CANN)
+ if (host_buffer) {
+ buft = ggml_backend_cann_host_buffer_type();
+ }
+#elif defined(GGML_USE_CPU_HBM)
+ buft = ggml_backend_cpu_hbm_buffer_type();
+#elif defined(GGML_USE_VULKAN)
+ if (host_buffer) {
+ buft = ggml_backend_vk_host_buffer_type();
+ }
#endif
- int local_gpu = gpu - rpc_count;
+
+ if (buft == nullptr) {
+ buft = ggml_backend_cpu_buffer_type();
+ }
+ return buft;
+
+ GGML_UNUSED(host_buffer);
+}
+
+static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int device) {
+ ggml_backend_buffer_type_t buft = nullptr;
+
#if defined(GGML_USE_RPC)
- if (gpu < rpc_count) {
- const char * endpoint = model.rpc_servers[gpu].c_str();
+ int rpc_count = (int)model.rpc_servers.size();
+ if (device < rpc_count) {
+ const char * endpoint = model.rpc_servers[device].c_str();
return ggml_backend_rpc_buffer_type(endpoint);
}
+ device -= rpc_count;
#endif
+
+ if (device < (int)model.devices.size()) {
+ return ggml_backend_dev_buffer_type(model.devices[device]);
+ }
+ device -= (int)model.devices.size();
+
#if defined(GGML_USE_METAL)
buft = ggml_backend_metal_buffer_type();
-#elif defined(GGML_USE_CUDA)
- buft = ggml_backend_cuda_buffer_type(local_gpu);
#elif defined(GGML_USE_VULKAN)
- buft = ggml_backend_vk_buffer_type(local_gpu);
+ buft = ggml_backend_vk_buffer_type(device);
#elif defined(GGML_USE_SYCL)
- buft = ggml_backend_sycl_buffer_type(local_gpu);
+ buft = ggml_backend_sycl_buffer_type(device);
#elif defined(GGML_USE_KOMPUTE)
- buft = ggml_backend_kompute_buffer_type(local_gpu);
- if (buft == nullptr) {
- LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, local_gpu);
- }
+ buft = ggml_backend_kompute_buffer_type(device);
#elif defined(GGML_USE_CANN)
- buft = ggml_backend_cann_buffer_type(local_gpu);
+ buft = ggml_backend_cann_buffer_type(device);
#endif
if (buft == nullptr) {
- buft = llama_default_buffer_type_cpu(true);
+ buft = llama_default_buffer_type_cpu(model, true);
}
return buft;
+
GGML_UNUSED(model);
- GGML_UNUSED(local_gpu);
}
static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
ggml_backend_buffer_type_t buft = nullptr;
-#ifdef GGML_USE_CUDA
- if (ggml_backend_cuda_get_device_count() > 1) {
- buft = ggml_backend_cuda_split_buffer_type(tensor_split);
+ // find a backend that supports split buffers
+ for (size_t i = 0; i < ggml_backend_reg_count(); ++i) {
+ ggml_backend_reg_t reg = ggml_backend_reg_get(i);
+
+ auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
+ if (ggml_backend_split_buffer_type_fn) {
+ buft = ggml_backend_split_buffer_type_fn(tensor_split);
+ if (buft != nullptr) {
+ break;
+ }
+ }
}
-#endif
#ifdef GGML_USE_SYCL
if (ggml_backend_sycl_get_device_count() > 1) {
}
static size_t llama_get_device_memory(const llama_model & model, int device) {
-#ifdef GGML_USE_RPC
- int rpc_count = (int)model.rpc_servers.size();
-#else
- int rpc_count = 0;
-#endif
- int local_device = device - rpc_count;
#if defined(GGML_USE_RPC)
+ int rpc_count = (int)model.rpc_servers.size();
if (device < rpc_count) {
size_t total;
size_t free;
ggml_backend_rpc_get_device_memory(endpoint, &free, &total);
return free;
}
+ device = device - rpc_count;
#endif
-#if defined(GGML_USE_CUDA)
- size_t total;
- size_t free;
- ggml_backend_cuda_get_device_memory(local_device, &free, &total);
- return free;
-#elif defined(GGML_USE_SYCL)
+
+ if (device < (int)model.devices.size()) {
+ ggml_backend_dev_t dev = model.devices[device];
+ size_t total;
+ size_t free;
+ ggml_backend_dev_memory(dev, &free, &total);
+ return free;
+ }
+
+#if defined(GGML_USE_SYCL)
size_t total;
size_t free;
- ggml_backend_sycl_get_device_memory(local_device, &free, &total);
+ ggml_backend_sycl_get_device_memory(device, &free, &total);
return free;
#elif defined(GGML_USE_VULKAN)
size_t total;
size_t free;
- ggml_backend_vk_get_device_memory(local_device, &free, &total);
+ ggml_backend_vk_get_device_memory(device, &free, &total);
return free;
#elif defined(GGML_USE_CANN)
size_t total;
size_t free;
- ggml_backend_cann_get_device_memory(local_device, &free, &total);
+ ggml_backend_cann_get_device_memory(device, &free, &total);
return free;
#else
return 1;
#endif
GGML_UNUSED(model);
- GGML_UNUSED(local_device);
+ GGML_UNUSED(device);
}
//
buft_layer_count[model.buft_layer[i].buft]++;
}
} else {
- buft_layer_count[llama_default_buffer_type_cpu(true)] = n_layer;
+ buft_layer_count[llama_default_buffer_type_cpu(model, true)] = n_layer;
}
// create a context for each buffer type
static const int TENSOR_NOT_REQUIRED = 1;
static const int TENSOR_DUPLICATED = 2;
- struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0) {
+ struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0) {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
if (cur == NULL) {
return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
}
- struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
+ struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true) {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
if (cur == NULL) {
std::array<int64_t, GGML_MAX_DIMS> dims;
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
- dims[i] = i < ne.size() ? ne[i] : 1;
+ dims[i] = i < ne.size() ? ne.begin()[i] : 1;
}
struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
// Returns false if cancelled by progress_callback
bool load_all_data(
struct ggml_context * ctx,
- llama_buf_map & bufs_mmap,
+ llama_buf_map & bufs,
llama_mlocks * lmlocks,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
std::vector<no_init<uint8_t>> read_buf;
std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
-#if defined(GGML_USE_CUDA)
// 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
// NVMe raid configurations might require more / larger buffers.
constexpr size_t n_buffers = 4;
constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
std::vector<ggml_backend_buffer_t> host_buffers;
- std::vector<void*> host_ptrs;
std::vector<ggml_backend_event_t> events;
+ std::vector<void *> host_ptrs;
size_t buffer_idx = 0; // buffer to use for async loads
-
- ggml_backend_t cuda_backend = nullptr;
- if (!use_mmap && !check_tensors) {
+ ggml_backend_t upload_backend = [&](const char * fn) -> ggml_backend_t {
+ if (use_mmap || check_tensors) {
+ return nullptr;
+ }
// When not using mmaped io use async uploads from pinned memory to GPU memory.
- // First determine if the CUDA backend is active, and if so, determine the device ID.
- ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
- if (buf) {
- ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
- for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
- auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
- if (buffer_type == cuda_buffer_type) {
- cuda_backend = ggml_backend_cuda_init(i);
- break;
- }
- }
+ // First determine if the backend supports the necessary features for async uploads.
+ auto * buf = bufs.count(0) ? bufs.at(0) : nullptr;
+ if (!buf) {
+ LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", fn);
+ return nullptr;
+ }
+
+ auto * buft = ggml_backend_buffer_get_type(buf);
+ auto * dev = ggml_backend_buft_get_device(buft);
+ if (!dev) {
+ LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", fn,
+ ggml_backend_buft_name(buft));
+ return nullptr;
+ }
+
+ if (buft != ggml_backend_dev_buffer_type(dev)) {
+ LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", fn,
+ ggml_backend_buft_name(buft), ggml_backend_dev_name(dev));
+ return nullptr;
+ }
+
+ ggml_backend_dev_props props;
+ ggml_backend_dev_get_props(dev, &props);
+ if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) {
+ LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", fn,
+ ggml_backend_dev_name(dev));
+ return nullptr;
+ }
+
+ auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+ if (!host_buft) {
+ LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", fn,
+ ggml_backend_dev_name(dev));
+ return nullptr;
}
- // If the cuda backend is active create pinned memory buffers and events for synchronisation.
- if (cuda_backend) {
- for (size_t idx = 0; idx < n_buffers; ++idx) {
- host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
- host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
- events.emplace_back(ggml_backend_event_new(cuda_backend));
+ // If the backend is supported, create pinned memory buffers and events for synchronisation.
+ for (size_t idx = 0; idx < n_buffers; ++idx) {
+ auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
+ if (!buf) {
+ LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", fn,
+ ggml_backend_dev_name(dev));
+ return nullptr;
+ }
+
+ host_buffers.emplace_back(buf);
+ host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf));
+
+ auto * event = ggml_backend_event_new(dev);
+ if (!event) {
+ LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", fn,
+ ggml_backend_dev_name(dev));
+ return nullptr;
}
+
+ events.emplace_back(event);
}
+
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+ if (!backend) {
+ LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", fn,
+ ggml_backend_dev_name(dev));
+ return nullptr;
+ }
+
+ return backend;
+ }(__func__);
+
+ if (upload_backend) {
+ LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__,
+ ggml_backend_dev_name(ggml_backend_get_device(upload_backend)),
+ ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))),
+ ggml_backend_name(upload_backend));
}
-#endif
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
const auto * weight = get_weight(ggml_get_name(cur));
if (use_mmap) {
const auto & mapping = mappings.at(weight->idx);
ggml_backend_buffer_t buf_mmap = nullptr;
- if (bufs_mmap.count(weight->idx)) {
- buf_mmap = bufs_mmap.at(weight->idx);
+ if (bufs.count(weight->idx)) {
+ buf_mmap = bufs.at(weight->idx);
}
uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
}));
}
} else {
-#if defined(GGML_USE_CUDA)
- // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
- if (cuda_backend) {
+ // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
+ if (upload_backend) {
file->seek(weight->offs, SEEK_SET);
size_t bytes_read = 0;
ggml_backend_event_synchronize(events[buffer_idx]);
file->read_raw(host_ptrs[buffer_idx], read_iteration);
- ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
- ggml_backend_event_record(events[buffer_idx]);
+ ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
+ ggml_backend_event_record(events[buffer_idx], upload_backend);
bytes_read += read_iteration;
++buffer_idx;
buffer_idx %= n_buffers;
}
- }
- else
-#endif
- {
+ } else {
read_buf.resize(n_size);
file->seek(weight->offs, SEEK_SET);
file->read_raw(read_buf.data(), n_size);
size_done += n_size;
}
-#if defined(GGML_USE_CUDA)
- // free temporary resources used for async cuda uploads
- if (cuda_backend) {
- for (size_t idx = 0; idx < n_buffers;++idx) {
- ggml_backend_event_synchronize(events[idx]);
- ggml_backend_event_free(events[idx]);
- ggml_backend_buffer_free(host_buffers[idx]);
- }
- ggml_backend_free(cuda_backend);
+ // free temporary resources used for async uploads
+ for (auto * event : events) {
+ ggml_backend_event_synchronize(event);
+ ggml_backend_event_free(event);
}
-#endif
+ for (auto * buf : host_buffers) {
+ ggml_backend_buffer_free(buf);
+ }
+ ggml_backend_free(upload_backend);
// check validation results
bool validation_failed = false;
void * progress_callback_user_data) {
auto & hparams = model.hparams;
+ // check if the value of main_gpu is valid
+ if (llama_get_device_count(model) > 0 &&
+ split_mode != LLAMA_SPLIT_MODE_LAYER &&
+ (main_gpu < 0 || main_gpu >= llama_get_device_count(model))) {
+ throw std::runtime_error(format("invalid value for main_gpu: %d (available devices: %d)", main_gpu, llama_get_device_count(model)));
+ }
+
model.split_mode = split_mode;
model.main_gpu = main_gpu;
model.n_gpu_layers = n_gpu_layers;
bool use_mmap_buffer = true;
// there is very little benefit to offloading the input layer, so always keep it on the CPU
- model.buft_input = llama_default_buffer_type_cpu(true);
+ model.buft_input = llama_default_buffer_type_cpu(model, true);
//model.buft_input = llama_default_buffer_type_offload(main_gpu);
model.buft_layer.resize(n_layer);
// assign cpu layers
for (int i = 0; i < i_gpu_start; ++i) {
- model.buft_layer[i] = llama_default_buffer_type_cpu(true);
+ model.buft_layer[i] = llama_default_buffer_type_cpu(model, true);
}
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
model.buft_output = llama_default_buffer_type_offload(model, layer_gpu);
} else {
- model.buft_output = llama_default_buffer_type_cpu(true);
+ model.buft_output = llama_default_buffer_type_cpu(model, true);
}
} else {
ggml_backend_buffer_type_t split_buft;
llama_default_buffer_type_offload(model, main_gpu)
};
} else {
- model.buft_output = llama_default_buffer_type_cpu(true);
+ model.buft_output = llama_default_buffer_type_cpu(model, true);
}
}
// only the mmap region containing the tensors in the model is mapped to the backend buffer
// this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
// this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
- if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) {
+ if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(model, true)) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
void * addr = nullptr;
size_t first, last;
}
model.bufs.push_back(buf);
bufs.emplace(idx, buf);
-#ifdef GGML_USE_CUDA
- if (n_layer >= n_gpu_layers) {
- ggml_backend_cuda_register_host_buffer(
- ggml_backend_buffer_get_base(buf),
- ggml_backend_buffer_get_size(buf));
- }
-#endif
}
}
#ifdef GGML_USE_METAL
lctx.embd = nullptr;
}
- lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
+ lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(lctx.model, true), new_size);
if (lctx.buf_output == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
return 0;
}
size_t llama_max_devices(void) {
-#if defined(GGML_USE_RPC)
- return GGML_RPC_MAX_SERVERS;
-#elif defined(GGML_USE_METAL)
- return 1;
-#elif defined(GGML_USE_CUDA)
- return GGML_CUDA_MAX_DEVICES;
-#elif defined(GGML_USE_SYCL)
- return GGML_SYCL_MAX_DEVICES;
-#elif defined(GGML_USE_VULKAN)
- return GGML_VK_MAX_DEVICES;
-#elif defined(GGML_USE_CANN)
- return GGML_CANN_MAX_DEVICES;
-#else
- return 1;
-#endif
+ return 16;
}
bool llama_supports_mmap(void) {
}
bool llama_supports_gpu_offload(void) {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
+#if defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
return true;
#else
- return false;
+ return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
+ ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU_FULL) != nullptr;
#endif
}
return true;
};
}
+
if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
// split the servers set them into model->rpc_servers
std::string servers(params.rpc_servers);
size_t pos = 0;
- while ((pos = servers.find(",")) != std::string::npos) {
+ while ((pos = servers.find(',')) != std::string::npos) {
std::string server = servers.substr(0, pos);
model->rpc_servers.push_back(server);
servers.erase(0, pos + 1);
}
model->rpc_servers.push_back(servers);
}
+
+ // create list of devices to use with this model
+ // currently, we use all available devices
+ // TODO: rework API to give user more control over device selection
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ // skip the CPU backend since it is handled separately
+ if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU_FULL) {
+ model->devices.push_back(dev);
+ }
+ }
+
int status = llama_model_load(path_model, *model, params);
GGML_ASSERT(status <= 0);
if (status < 0) {
if (!hparams.vocab_only) {
// initialize backends
+ int main_gpu = model->main_gpu;
+
+ // with registry
+ if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+ if (main_gpu >= 0 && main_gpu < (int)model->devices.size()) {
+ ggml_backend_dev_t main_dev = model->devices[main_gpu];
+ ggml_backend_t backend = ggml_backend_dev_init(main_dev, nullptr);
+ if (backend == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(main_dev));
+ llama_free(ctx);
+ return nullptr;
+ }
+ ctx->backends.push_back(backend);
+ }
+ } else {
+ // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+ for (auto * dev : model->devices) {
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+ if (backend == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ llama_free(ctx);
+ return nullptr;
+ }
+ ctx->backends.push_back(backend);
+ }
+ }
+ if (main_gpu >= (int)model->devices.size()) {
+ main_gpu -= (int)model->devices.size();
+ }
+
#if defined(GGML_USE_RPC)
if (model->n_gpu_layers > 0) {
for (const auto & endpoint : model->rpc_servers) {
ctx->backends.push_back(backend);
}
}
+ if (main_gpu >= (int)model->rpc_servers.size()) {
+ main_gpu -= (int)model->rpc_servers.size();
+ }
#endif
#if defined(GGML_USE_METAL)
}
ctx->backends.push_back(ctx->backend_metal);
}
-#elif defined(GGML_USE_CUDA)
- if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
- // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
- ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
- if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
- llama_free(ctx);
- return nullptr;
- }
- ctx->backends.push_back(backend);
- } else {
- // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
- for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
- ggml_backend_t backend = ggml_backend_cuda_init(device);
- if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
- llama_free(ctx);
- return nullptr;
- }
- ctx->backends.push_back(backend);
- }
- }
#elif defined(GGML_USE_VULKAN)
if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
return nullptr;
}
if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
- ggml_backend_t backend = ggml_backend_vk_init(model->main_gpu);
+ ggml_backend_t backend = ggml_backend_vk_init(main_gpu);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
llama_free(ctx);
#elif defined(GGML_USE_SYCL)
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
- ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
+ ggml_backend_t backend = ggml_backend_sycl_init(main_gpu);
if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
+ LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, main_gpu);
llama_free(ctx);
return nullptr;
}
}
#elif defined(GGML_USE_KOMPUTE)
if (model->n_gpu_layers > 0) {
- auto * backend = ggml_backend_kompute_init(model->main_gpu);
+ auto * backend = ggml_backend_kompute_init(main_gpu);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
llama_free(ctx);
ctx->backends.push_back(backend);
}
#elif defined(GGML_USE_CANN)
- // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
- // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
- if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
- ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu);
- if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
- llama_free(ctx);
- return nullptr;
- }
- ctx->backends.push_back(backend);
- } else {
- // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
- // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
- for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
- ggml_backend_t backend = ggml_backend_cann_init(device);
+ // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+ // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
+ if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+ ggml_backend_t backend = ggml_backend_cann_init(main_gpu);
if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+ LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, main_gpu);
llama_free(ctx);
return nullptr;
}
ctx->backends.push_back(backend);
+ } else {
+ // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+ // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
+ for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
+ ggml_backend_t backend = ggml_backend_cann_init(device);
+ if (backend == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+ llama_free(ctx);
+ return nullptr;
+ }
+ ctx->backends.push_back(backend);
+ }
}
- }
#endif
#ifdef GGML_USE_BLAS
for (auto * backend : ctx->backends) {
if (ggml_backend_is_cpu(backend)) {
// use host buffers for the CPU backend compute buffer
- backend_buft.push_back(llama_default_buffer_type_cpu(true));
+ backend_buft.push_back(llama_default_buffer_type_cpu(*model, true));
} else {
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
}
// buffer used to store the computation graph and the tensor meta data
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+ // TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
llama_get_device_count(*model) > 1 &&
model->n_gpu_layers > (int)model->hparams.n_layer &&
model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
params.offload_kqv;
-#ifndef GGML_USE_CUDA
- // pipeline parallelism requires support for async compute and events
- // currently this is only implemented in the CUDA backend
- pipeline_parallel = false;
-#endif
+
+ // pipeline parallelism requires support for async compute and events in all devices
+ if (pipeline_parallel) {
+ for (auto * backend : ctx->backends) {
+ if (ggml_backend_is_cpu(backend)) {
+ // ignore CPU backend
+ continue;
+ }
+ auto * dev = ggml_backend_get_device(backend);
+ if (!dev) {
+ // backend is using old interface, not supported
+ pipeline_parallel = false;
+ break;
+ }
+ ggml_backend_dev_props props;
+ ggml_backend_dev_get_props(dev, &props);
+ if (!props.caps.async || !props.caps.events) {
+ // device does not support async compute or events
+ pipeline_parallel = false;
+ break;
+ }
+ }
+ }
+
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);
if (pipeline_parallel) {
}
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
- g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
- g_state.log_callback_user_data = user_data;
-#ifdef GGML_USE_METAL
- ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(GGML_USE_CUDA)
- ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(GGML_USE_CANN)
- ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#endif
+ ggml_log_set(log_callback, user_data);
+ g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
+ g_logger_state.log_callback_user_data = user_data;
}
static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
char buffer[128];
int len = vsnprintf(buffer, 128, format, args);
if (len < 128) {
- g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+ g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
} else {
char * buffer2 = new char[len + 1];
vsnprintf(buffer2, len + 1, format, args_copy);
buffer2[len] = 0;
- g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+ g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
delete[] buffer2;
}
va_end(args_copy);
+++ /dev/null
-#include "ggml-backend-impl.h"
-#include "ggml-alloc.h"
-#include "ggml-impl.h"
-
-#include <assert.h>
-#include <limits.h>
-#include <stdarg.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-
-
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
-// backend buffer type
-
-const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name(buft);
-}
-
-GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- return buft->iface.alloc_buffer(buft, size);
-}
-
-size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
- return buft->iface.get_alignment(buft);
-}
-
-size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
- // get_max_size is optional, defaults to SIZE_MAX
- if (buft->iface.get_max_size) {
- return buft->iface.get_max_size(buft);
- }
- return SIZE_MAX;
-}
-
-GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
- // get_alloc_size is optional, defaults to ggml_nbytes
- if (buft->iface.get_alloc_size) {
- size_t size = buft->iface.get_alloc_size(buft, tensor);
- assert(size >= ggml_nbytes(tensor));
- return size;
- }
- return ggml_nbytes(tensor);
-}
-
-bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
- if (buft->iface.is_host) {
- return buft->iface.is_host(buft);
- }
- return false;
-}
-
-// backend buffer
-
-GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
- ggml_backend_buffer_type_t buft,
- struct ggml_backend_buffer_i iface,
- ggml_backend_buffer_context_t context,
- size_t size) {
- ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
-
- (*buffer) = (struct ggml_backend_buffer) {
- /* .interface = */ iface,
- /* .buft = */ buft,
- /* .context = */ context,
- /* .size = */ size,
- /* .usage = */ GGML_BACKEND_BUFFER_USAGE_ANY
- };
-
- return buffer;
-}
-
-const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {
- return buffer->iface.get_name(buffer);
-}
-
-void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
- if (buffer == NULL) {
- return;
- }
-
- if (buffer->iface.free_buffer != NULL) {
- buffer->iface.free_buffer(buffer);
- }
- free(buffer);
-}
-
-size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
- return buffer->size;
-}
-
-void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
- void * base = buffer->iface.get_base(buffer);
-
- GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
-
- return base;
-}
-
-GGML_CALL void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
- // init_tensor is optional
- if (buffer->iface.init_tensor) {
- buffer->iface.init_tensor(buffer, tensor);
- }
-}
-
-size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
- return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));
-}
-
-size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
- return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
-}
-
-size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
- return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
-}
-
-void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- buffer->iface.clear(buffer, value);
-}
-
-bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
- return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
-}
-
-void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
- buffer->usage = usage;
-
- // FIXME: add a generic callback to the buffer interface
- if (ggml_backend_buffer_is_multi_buffer(buffer)) {
- ggml_backend_multi_buffer_set_usage(buffer, usage);
- }
-}
-
-enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
- return buffer->usage;
-}
-
-ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
- return buffer->buft;
-}
-
-void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
- if (buffer->iface.reset) {
- buffer->iface.reset(buffer);
- }
-}
-
-bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
- if (dst_buf->iface.cpy_tensor) {
- return dst_buf->iface.cpy_tensor(dst_buf, src, dst);
- }
- return false;
-}
-
-// backend
-
-ggml_guid_t ggml_backend_guid(ggml_backend_t backend) {
- if (backend == NULL) {
- return NULL;
- }
- return backend->guid;
-}
-
-const char * ggml_backend_name(ggml_backend_t backend) {
- if (backend == NULL) {
- return "NULL";
- }
- return backend->iface.get_name(backend);
-}
-
-void ggml_backend_free(ggml_backend_t backend) {
- if (backend == NULL) {
- return;
- }
-
- backend->iface.free(backend);
-}
-
-ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
- return backend->iface.get_default_buffer_type(backend);
-}
-
-ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
- return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
-}
-
-size_t ggml_backend_get_alignment(ggml_backend_t backend) {
- return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
-}
-
-size_t ggml_backend_get_max_size(ggml_backend_t backend) {
- return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend));
-}
-
-void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-
- if (backend->iface.set_tensor_async == NULL) {
- ggml_backend_tensor_set(tensor, data, offset, size);
- } else {
- backend->iface.set_tensor_async(backend, tensor, data, offset, size);
- }
-}
-
-void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-
- if (backend->iface.get_tensor_async == NULL) {
- ggml_backend_tensor_get(tensor, data, offset, size);
- } else {
- backend->iface.get_tensor_async(backend, tensor, data, offset, size);
- }
-}
-
-GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf != NULL && "tensor buffer not set");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-
- if (!size) {
- return;
- }
-
- buf->iface.set_tensor(buf, tensor, data, offset, size);
-}
-
-GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf != NULL && "tensor buffer not set");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-
- if (!size) {
- return;
- }
-
- buf->iface.get_tensor(buf, tensor, data, offset, size);
-}
-
-GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf != NULL && "tensor buffer not set");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-
- if (!size) {
- return;
- }
-
- GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
-
- buf->iface.memset_tensor(buf, tensor, value, offset, size);
-}
-
-void ggml_backend_synchronize(ggml_backend_t backend) {
- if (backend->iface.synchronize == NULL) {
- return;
- }
-
- backend->iface.synchronize(backend);
-}
-
-ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- GGML_ASSERT(backend->iface.graph_plan_create != NULL);
-
- return backend->iface.graph_plan_create(backend, cgraph);
-}
-
-void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
- GGML_ASSERT(backend->iface.graph_plan_free != NULL);
-
- backend->iface.graph_plan_free(backend, plan);
-}
-
-enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
- GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
-
- return backend->iface.graph_plan_compute(backend, plan);
-}
-
-enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
- ggml_backend_synchronize(backend);
- return err;
-}
-
-enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- return backend->iface.graph_compute(backend, cgraph);
-}
-
-bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return backend->iface.supports_op(backend, op);
-}
-
-bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
- return backend->iface.supports_buft(backend, buft);
-}
-
-bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- if (backend->iface.offload_op != NULL) {
- return backend->iface.offload_op(backend, op);
- }
- return false;
-}
-
-// backend copy
-
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
- if (a->type != b->type) {
- return false;
- }
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- if (a->ne[i] != b->ne[i]) {
- return false;
- }
- if (a->nb[i] != b->nb[i]) {
- return false;
- }
- }
- return true;
-}
-
-void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
-
- if (src == dst) {
- return;
- }
-
- if (ggml_backend_buffer_is_host(src->buffer)) {
- ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
- } else if (ggml_backend_buffer_is_host(dst->buffer)) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
- } else if (!ggml_backend_buffer_copy_tensor(src, dst)) {
-#ifndef NDEBUG
- fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
-#endif
- size_t nbytes = ggml_nbytes(src);
- void * data = malloc(nbytes);
- ggml_backend_tensor_get(src, data, 0, nbytes);
- ggml_backend_tensor_set(dst, data, 0, nbytes);
- free(data);
- }
-}
-
-void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
-
- if (src == dst) {
- return;
- }
-
- if (backend_dst->iface.cpy_tensor_async != NULL) {
- if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
- return;
- }
- }
-
- // an async copy would normally happen after all the queued operations on both backends are completed
- // to simulate the same behavior, we need to synchronize both backends first, and do a blocking copy
- ggml_backend_synchronize(backend_src);
- ggml_backend_synchronize(backend_dst);
- ggml_backend_tensor_copy(src, dst);
-}
-
-// events
-
-ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) {
- if (backend->iface.event_new == NULL) {
- return NULL;
- }
- return backend->iface.event_new(backend);
-}
-
-void ggml_backend_event_free(ggml_backend_event_t event) {
- if (event == NULL) {
- return;
- }
- event->backend->iface.event_free(event);
-}
-
-void ggml_backend_event_record(ggml_backend_event_t event) {
- GGML_ASSERT(event->backend->iface.event_record != NULL);
-
- event->backend->iface.event_record(event);
-}
-
-void ggml_backend_event_synchronize(ggml_backend_event_t event) {
- GGML_ASSERT(event->backend->iface.event_synchronize != NULL);
-
- event->backend->iface.event_synchronize(event);
-}
-
-void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
- GGML_ASSERT(backend->iface.event_wait != NULL);
-
- backend->iface.event_wait(backend, event);
-}
-
-// backend registry
-
-#define GGML_REG_MAX_BACKENDS 64
-
-struct ggml_backend_reg {
- char name[128];
- ggml_backend_init_fn init_fn;
- ggml_backend_buffer_type_t default_buffer_type;
- void * user_data;
-};
-
-static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS];
-static size_t ggml_backend_registry_count = 0;
-
-GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
-
-GGML_CALL static void ggml_backend_registry_init(void) {
- static bool initialized = false;
-
- if (initialized) {
- return;
- }
-
- initialized = true;
-
- ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
-
- // add forward decls here to avoid including the backend headers
-#ifdef GGML_USE_CUDA
- extern GGML_CALL void ggml_backend_cuda_reg_devices(void);
- ggml_backend_cuda_reg_devices();
-#endif
-
-#ifdef GGML_USE_SYCL
- extern void ggml_backend_sycl_reg_devices(void);
- ggml_backend_sycl_reg_devices();
-#endif
-
-#ifdef GGML_USE_METAL
- extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
- extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
- ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
-#endif
-
-#ifdef GGML_USE_VULKAN
- extern GGML_CALL int ggml_backend_vk_reg_devices(void);
- ggml_backend_vk_reg_devices();
-#endif
-
-#ifdef GGML_USE_KOMPUTE
- extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
- ggml_backend_kompute_reg_devices();
-#endif
-
-#ifdef GGML_USE_CANN
- extern GGML_CALL int ggml_backend_cann_reg_devices(void);
- ggml_backend_cann_reg_devices();
-#endif
-}
-
-GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
- GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS);
-
- size_t id = ggml_backend_registry_count;
-
- ggml_backend_registry[id] = (struct ggml_backend_reg) {
- /* .name = */ {0},
- /* .fn = */ init_fn,
- /* .default_buffer_type = */ default_buffer_type,
- /* .user_data = */ user_data,
- };
-
- snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
-
-#ifndef NDEBUG
- fprintf(stderr, "%s: registered backend %s\n", __func__, name);
-#endif
-
- ggml_backend_registry_count++;
-}
-
-size_t ggml_backend_reg_get_count(void) {
- ggml_backend_registry_init();
-
- return ggml_backend_registry_count;
-}
-
-size_t ggml_backend_reg_find_by_name(const char * name) {
- ggml_backend_registry_init();
-
- for (size_t i = 0; i < ggml_backend_registry_count; i++) {
- // TODO: case insensitive in a portable way
- if (strcmp(ggml_backend_registry[i].name, name) == 0) {
- return i;
- }
- }
-
- // not found
- return SIZE_MAX;
-}
-
-// init from backend:params string
-ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
- ggml_backend_registry_init();
-
- const char * params = strchr(backend_str, ':');
- char backend_name[128];
- if (params == NULL) {
- snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
- params = "";
- } else {
- snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
- params++;
- }
-
- size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
-
- if (backend_i == SIZE_MAX) {
- fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
- return NULL;
- }
-
- return ggml_backend_reg_init_backend(backend_i, params);
-}
-
-const char * ggml_backend_reg_get_name(size_t i) {
- ggml_backend_registry_init();
-
- GGML_ASSERT(i < ggml_backend_registry_count);
- return ggml_backend_registry[i].name;
-}
-
-ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
- ggml_backend_registry_init();
-
- GGML_ASSERT(i < ggml_backend_registry_count);
- return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
-}
-
-ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
- ggml_backend_registry_init();
-
- GGML_ASSERT(i < ggml_backend_registry_count);
- return ggml_backend_registry[i].default_buffer_type;
-}
-
-ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
- ggml_backend_registry_init();
-
- GGML_ASSERT(i < ggml_backend_registry_count);
- return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
-}
-
-// backend CPU
-
-static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
-
-GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) {
- return "CPU";
-
- GGML_UNUSED(buffer);
-}
-
-GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
- uintptr_t data = (uintptr_t)buffer->context;
-
- // align the buffer
- if (data % TENSOR_ALIGNMENT != 0) {
- data = GGML_PAD(data, TENSOR_ALIGNMENT);
- }
-
- return (void *)data;
-}
-
-GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
-}
-
-GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
- memset((char *)tensor->data + offset, value, size);
-
- GGML_UNUSED(buffer);
-}
-
-GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- memcpy((char *)tensor->data + offset, data, size);
-
- GGML_UNUSED(buffer);
-}
-
-GGML_CALL static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- memcpy(data, (const char *)tensor->data + offset, size);
-
- GGML_UNUSED(buffer);
-}
-
-GGML_CALL static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
- if (ggml_backend_buffer_is_host(src->buffer)) {
- memcpy(dst->data, src->data, ggml_nbytes(src));
- return true;
- }
- return false;
-
- GGML_UNUSED(buffer);
-}
-
-GGML_CALL static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- memset(buffer->context, value, buffer->size);
-}
-
-static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
- /* .get_name = */ ggml_backend_cpu_buffer_name,
- /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
- /* .get_base = */ ggml_backend_cpu_buffer_get_base,
- /* .init_tensor = */ NULL, // no initialization required
- /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
- /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_cpu_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// for buffers from ptr, free is not called
-static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
- /* .get_name = */ ggml_backend_cpu_buffer_name,
- /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
- /* .get_base = */ ggml_backend_cpu_buffer_get_base,
- /* .init_tensor = */ NULL, // no initialization required
- /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
- /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_cpu_buffer_clear,
- /* .reset = */ NULL,
-};
-
-GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "CPU";
-
- GGML_UNUSED(buft);
-}
-
-GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
- void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
- if (data == NULL) {
- fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
- return NULL;
- }
-
- return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
-}
-
-GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return TENSOR_ALIGNMENT;
-
- GGML_UNUSED(buft);
-}
-
-GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return true;
-
- GGML_UNUSED(buft);
-}
-
-GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
- static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_cpu_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
- },
- /* .context = */ NULL,
- };
-
- return &ggml_backend_cpu_buffer_type;
-}
-
-#ifdef GGML_USE_CPU_HBM
-
-// buffer type HBM
-
-#include <hbwmalloc.h>
-
-GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "CPU_HBM";
-
- GGML_UNUSED(buft);
-}
-
-GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) {
- return "CPU_HBM";
-
- GGML_UNUSED(buf);
-}
-
-GGML_CALL static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- hbw_free(buffer->context);
-}
-
-GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- //void * ptr = hbw_malloc(size);
- void * ptr;
- int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
- if (result != 0) {
- fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size);
- return NULL;
- }
-
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
- buffer->buft = buft;
- buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name;
- buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
-
- return buffer;
-}
-
-ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
- static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
- },
- /* .context = */ NULL,
- };
-
- return &ggml_backend_cpu_buffer_type_hbm;
-}
-#endif
-
-struct ggml_backend_cpu_context {
- int n_threads;
- ggml_threadpool_t threadpool;
-
- void * work_data;
- size_t work_size;
-
- ggml_abort_callback abort_callback;
- void * abort_callback_data;
-};
-
-GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
- return "CPU";
-
- GGML_UNUSED(backend);
-}
-
-GGML_CALL static void ggml_backend_cpu_free(ggml_backend_t backend) {
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
- free(cpu_ctx->work_data);
- free(cpu_ctx);
- free(backend);
-}
-
-GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
- return ggml_backend_cpu_buffer_type();
-
- GGML_UNUSED(backend);
-}
-
-struct ggml_backend_plan_cpu {
- struct ggml_cplan cplan;
- struct ggml_cgraph cgraph;
-};
-
-GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-
- struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
-
- cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
- cpu_plan->cgraph = *cgraph; // FIXME: deep copy
-
- if (cpu_plan->cplan.work_size > 0) {
- cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
- if (cpu_plan->cplan.work_data == NULL) {
- free(cpu_plan);
- return NULL;
- }
- }
-
- cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
- cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
-
- return cpu_plan;
-}
-
-GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
- struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
-
- free(cpu_plan->cplan.work_data);
- free(cpu_plan);
-
- GGML_UNUSED(backend);
-}
-
-GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
- struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
-
- return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
-
- GGML_UNUSED(backend);
-}
-
-GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-
- struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
-
- if (cpu_ctx->work_size < cplan.work_size) {
- free(cpu_ctx->work_data);
- cpu_ctx->work_data = malloc(cplan.work_size);
- if (cpu_ctx->work_data == NULL) {
- cpu_ctx->work_size = 0;
- return GGML_STATUS_ALLOC_FAILED;
- }
- cpu_ctx->work_size = cplan.work_size;
- }
- cplan.work_data = cpu_ctx->work_data;
-
- cplan.abort_callback = cpu_ctx->abort_callback;
- cplan.abort_callback_data = cpu_ctx->abort_callback_data;
-
- return ggml_graph_compute(cgraph, &cplan);
-}
-
-GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_CPY:
- return
- op->type != GGML_TYPE_IQ2_XXS &&
- op->type != GGML_TYPE_IQ2_XS &&
- op->type != GGML_TYPE_IQ1_S &&
- op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
- case GGML_OP_MUL_MAT:
- return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
- case GGML_OP_ROPE_BACK:
- return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
- case GGML_OP_IM2COL_BACK:
- return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
- default:
- return true;
- }
-
- GGML_UNUSED(backend);
-}
-
-GGML_CALL static bool ggml_backend_cpu_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
- return ggml_backend_buft_is_host(buft);
-
- GGML_UNUSED(backend);
-}
-
-static struct ggml_backend_i cpu_backend_i = {
- /* .get_name = */ ggml_backend_cpu_name,
- /* .free = */ ggml_backend_cpu_free,
- /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
- /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
- /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
- /* .graph_compute = */ ggml_backend_cpu_graph_compute,
- /* .supports_op = */ ggml_backend_cpu_supports_op,
- /* .supports_buft = */ ggml_backend_cpu_supports_buft,
- /* .offload_op = */ NULL,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_cpu_guid(void) {
- static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
- return &guid;
-}
-
-ggml_backend_t ggml_backend_cpu_init(void) {
- struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
- if (ctx == NULL) {
- return NULL;
- }
-
- ctx->n_threads = GGML_DEFAULT_N_THREADS;
- ctx->threadpool = NULL;
- ctx->work_data = NULL;
- ctx->work_size = 0;
- ctx->abort_callback = NULL;
- ctx->abort_callback_data = NULL;
-
- ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
- if (cpu_backend == NULL) {
- free(ctx);
- return NULL;
- }
-
- *cpu_backend = (struct ggml_backend) {
- /* .guid = */ ggml_backend_cpu_guid(),
- /* .interface = */ cpu_backend_i,
- /* .context = */ ctx
- };
- return cpu_backend;
-}
-
-GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
-}
-
-void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
- GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
-
- struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
- ctx->n_threads = n_threads;
-}
-
-void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
- GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
-
- struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
-
- if (ctx->threadpool && ctx->threadpool != threadpool) {
- // already had a different threadpool, pause/suspend it before switching
- ggml_threadpool_pause(ctx->threadpool);
- }
- ctx->threadpool = threadpool;
-}
-
-void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
- GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
-
- struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
- ctx->abort_callback = abort_callback;
- ctx->abort_callback_data = abort_callback_data;
-}
-
-GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
- GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
- return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
-}
-
-GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
- return ggml_backend_cpu_init();
-
- GGML_UNUSED(params);
- GGML_UNUSED(user_data);
-}
-
-// multi-buffer buffer
-
-struct ggml_backend_multi_buffer_context {
- ggml_backend_buffer_t * buffers;
- size_t n_buffers;
-};
-
-typedef struct ggml_backend_multi_buffer_context * ggml_backend_multi_buffer_context_t;
-
-GGML_CALL static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) {
- ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
-
- return ctx->buffers[0]->iface.get_name(ctx->buffers[0]);
-}
-
-GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
- for (size_t i = 0; i < ctx->n_buffers; i++) {
- ggml_backend_buffer_free(ctx->buffers[i]);
- }
-
- free(ctx->buffers);
- free(ctx);
-}
-
-GGML_CALL static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
- for (size_t i = 0; i < ctx->n_buffers; i++) {
- ggml_backend_buffer_clear(ctx->buffers[i], value);
- }
-}
-
-static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(void) {
- static struct ggml_backend_buffer_i multi_backend_buffer_i = {
- /* .get_name = */ ggml_backend_multi_buffer_get_name,
- /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer,
- /* .get_base = */ NULL,
- /* .init_tensor = */ NULL,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ NULL,
- /* .get_tensor = */ NULL,
- /* .cpy_tensor = */ NULL,
- /* .clear = */ ggml_backend_multi_buffer_clear,
- /* .reset = */ NULL,
- };
-
- return multi_backend_buffer_i;
-}
-
-GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) {
- ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) malloc(sizeof(struct ggml_backend_multi_buffer_context));
- ctx->n_buffers = n_buffers;
- ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t));
-
- GGML_ASSERT(ctx->buffers != NULL);
-
- size_t total_size = 0;
- for (size_t i = 0; i < n_buffers; i++) {
- ctx->buffers[i] = buffers[i];
- total_size += ggml_backend_buffer_get_size(buffers[i]);
- }
-
- return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_context_interface(), ctx, total_size);
-}
-
-GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
- return buffer->iface.get_name == ggml_backend_multi_buffer_get_name;
-}
-
-GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
- GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
- ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
- for (size_t i = 0; i < ctx->n_buffers; i++) {
- ggml_backend_buffer_set_usage(ctx->buffers[i], usage);
- }
-}
-
-// creates a copy of the tensor with the same memory layout
-static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
- struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- dup->nb[i] = tensor->nb[i];
- }
- return dup;
-}
-
-static bool ggml_is_view_op(enum ggml_op op) {
- return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
-}
-
-// scheduler
-
-#ifndef GGML_SCHED_MAX_BACKENDS
-#define GGML_SCHED_MAX_BACKENDS 16
-#endif
-
-#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
-#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
-#endif
-
-#ifndef GGML_SCHED_MAX_COPIES
-#define GGML_SCHED_MAX_COPIES 4
-#endif
-
-struct ggml_backend_sched_split {
- int backend_id;
- int i_start;
- int i_end;
- struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
- int n_inputs;
- // graph view of this split
- struct ggml_cgraph graph;
-};
-
-struct ggml_backend_sched {
- bool is_reset; // true if the scheduler has been reset since the last graph split
- bool is_alloc;
-
- int n_backends;
-
- ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];
- ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
- ggml_gallocr_t galloc;
-
- // hash map of the nodes in the graph
- struct ggml_hash_set hash_set;
- int * hv_tensor_backend_ids; // [hash_set.size]
- struct ggml_tensor ** hv_tensor_copies; // [hash_set.size][n_backends][n_copies]
-
- int * node_backend_ids; // [graph_size]
- int * leaf_backend_ids; // [graph_size]
-
- int * prev_node_backend_ids; // [graph_size]
- int * prev_leaf_backend_ids; // [graph_size]
-
- // copy of the graph with modified inputs
- struct ggml_cgraph graph;
-
- // graph splits
- struct ggml_backend_sched_split * splits;
- int n_splits;
- int splits_capacity;
-
- // pipeline parallelism support
- int n_copies;
- int cur_copy;
- ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
- struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
- int n_graph_inputs;
-
- struct ggml_context * ctx;
-
- ggml_backend_sched_eval_callback callback_eval;
- void * callback_eval_user_data;
-
- char * context_buffer;
- size_t context_buffer_size;
-
- bool debug;
-};
-
-#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
-#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)]
-#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)]
-#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id)
-
-// returns the priority of the backend, lower id is higher priority
-static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
- for (int i = 0; i < sched->n_backends; i++) {
- if (sched->backends[i] == backend) {
- return i;
- }
- }
- return -1;
-}
-
-static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
- ggml_backend_buffer_t buffer = tensor->buffer;
- if (buffer == NULL) {
- return -1;
- }
-
- // find highest prio backend that supports the buffer type and the op
- for (int i = 0; i < sched->n_backends; i++) {
- if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
- ggml_backend_supports_op(sched->backends[i], op)) {
- return i;
- }
- }
-
-#ifndef NDEBUG
- fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
- __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);
-#endif
-
- return -1;
-}
-
-#if 0
-#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
-static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
-#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
-#define GET_CAUSE(node) causes[hash_id(node)]
-#else
-#define SET_CAUSE(node, ...)
-#define GET_CAUSE(node) ""
-#endif
-
-// returns the backend that should be used for the node based on the current locations
-static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
- // TODO: use supports_op to check if the backend supports the op
-
- // assign pre-allocated nodes to their backend
- int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
- if (cur_backend_id != -1) {
- SET_CAUSE(tensor, "1.dst");
- return cur_backend_id;
- }
-
- // view_src
- if (tensor->view_src != NULL) {
- cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
- if (cur_backend_id != -1) {
- SET_CAUSE(tensor, "1.vsrc");
- return cur_backend_id;
- }
- }
-
- if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
- // since the tensor is pre-allocated, it cannot be moved to another backend
- GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
- }
-
- // graph input
- if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
- cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
- SET_CAUSE(tensor, "1.inp");
- return cur_backend_id;
- }
-
- // operations with weights are preferably run on the same backend as the weights
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- const struct ggml_tensor * src = tensor->src[i];
- if (src == NULL) {
- continue;
- }
- if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
- int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
- // check if a backend with higher prio wants to offload the op
- if (src_backend_id == sched->n_backends - 1) {
- for (int b = 0; b < src_backend_id; b++) {
- if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
- SET_CAUSE(tensor, "1.off");
- return b;
- }
- }
- }
- SET_CAUSE(tensor, "1.wgt%d", i);
- return src_backend_id;
- }
- }
-
- return -1;
-}
-
-static char * fmt_size(size_t size) {
- static char buffer[128];
- if (size >= 1024*1024) {
- snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
- } else {
- snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
- }
- return buffer;
-}
-
-static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
- int cur_split = 0;
- for (int i = 0; i < graph->n_nodes; i++) {
- if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
- ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
- fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
- sched->splits[cur_split].n_inputs);
- for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
- fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
- fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
- }
- fprintf(stderr, "\n");
- cur_split++;
- }
- struct ggml_tensor * node = graph->nodes[i];
- if (ggml_is_view_op(node->op)) {
- continue;
- }
- ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
- fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
- fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
- ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
- fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
- fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
- }
- fprintf(stderr, "\n");
- }
-}
-
-static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {
- ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
- ggml_backend_buffer_type_t buft = NULL;
-
- if (buf) {
- // the tensor is already allocated
- buft = buf->buft;
- } else {
- // see if the tensor already has a backend assigned, and use the buffer type of that backend
- int tensor_backend_id = tensor_backend_id(t);
- if (tensor_backend_id == -1 && t->view_src) {
- tensor_backend_id = tensor_backend_id(t->view_src);
- }
- if (tensor_backend_id != -1) {
- buft = sched->bufts[tensor_backend_id];
- }
- }
-
- return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft);
-}
-
-static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
- if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
- *node_backend_id = cur_backend_id;
- SET_CAUSE(node, "2.sup");
- }
-}
-
-// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
- // reset splits
- sched->n_splits = 0;
- sched->n_graph_inputs = 0;
- sched->is_reset = false;
-
- struct ggml_init_params params = {
- /* .mem_size = */ sched->context_buffer_size,
- /* .mem_buffer = */ sched->context_buffer,
- /* .no_alloc = */ true
- };
-
- ggml_free(sched->ctx);
-
- sched->ctx = ggml_init(params);
- if (sched->ctx == NULL) {
- GGML_ABORT("%s: failed to initialize context\n", __func__);
- }
-
- // pass 1: assign backends to ops with pre-allocated inputs
- for (int i = 0; i < graph->n_leafs; i++) {
- struct ggml_tensor * leaf = graph->leafs[i];
- int * leaf_backend_id = &tensor_backend_id(leaf);
- // do not overwrite user assignments
- if (*leaf_backend_id == -1) {
- *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
- }
- }
-
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- int * node_backend_id = &tensor_backend_id(node);
- // do not overwrite user assignments
- if (*node_backend_id == -1) {
- *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
-
-#if 0
- // src
- if (node->op == GGML_OP_NONE) {
- continue;
- }
-
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
- int * src_backend_id = &tensor_backend_id(src);
- if (*src_backend_id == -1) {
- *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
- }
- }
-#endif
- }
- }
-
- // pass 2: expand current backend assignments
- // assign the same backend to adjacent nodes
- // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend)
- // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
- // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of its inputs are known
- // expand gpu down
- {
- int cur_backend_id = -1;
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- if (ggml_is_view_op(node->op)) {
- continue;
- }
- int * node_backend_id = &tensor_backend_id(node);
- if (*node_backend_id != -1) {
- if (*node_backend_id == sched->n_backends - 1) {
- // skip cpu (lowest prio backend)
- cur_backend_id = -1;
- } else {
- cur_backend_id = *node_backend_id;
- }
- } else if (cur_backend_id != -1) {
- ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
- }
- }
- }
- // expand gpu up
- {
- int cur_backend_id = -1;
- for (int i = graph->n_nodes - 1; i >= 0; i--) {
- struct ggml_tensor * node = graph->nodes[i];
- if (ggml_is_view_op(node->op)) {
- continue;
- }
- int * node_backend_id = &tensor_backend_id(node);
- if (*node_backend_id != -1) {
- if (*node_backend_id == sched->n_backends - 1) {
- // skip cpu (lowest prio backend)
- cur_backend_id = -1;
- } else {
- cur_backend_id = *node_backend_id;
- }
- } else if (cur_backend_id != -1) {
- ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
- }
- }
- }
- // expand rest down
- {
- int cur_backend_id = -1;
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- if (ggml_is_view_op(node->op)) {
- continue;
- }
- int * node_backend_id = &tensor_backend_id(node);
- if (*node_backend_id != -1) {
- cur_backend_id = *node_backend_id;
- } else if (cur_backend_id != -1) {
- ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
- }
- }
- }
- // expand rest up
- {
- int cur_backend_id = -1;
- for (int i = graph->n_nodes - 1; i >= 0; i--) {
- struct ggml_tensor * node = graph->nodes[i];
- if (ggml_is_view_op(node->op)) {
- continue;
- }
- int * node_backend_id = &tensor_backend_id(node);
- if (*node_backend_id != -1) {
- cur_backend_id = *node_backend_id;
- } else if (cur_backend_id != -1) {
- ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
- }
- }
- }
-
- // pass 3: upgrade nodes to higher prio backends with compatible buffer types
- // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there
- // however, we also need to verify that the sources are in compatible buffer types
- // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph
- // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same
- // this is not uncommon since multiple backends can use host memory, with the same buffer type (eg. BLAS and CPU)
- // additionally, set remaining unassigned nodes to the backend with the most supported inputs
- // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- if (ggml_is_view_op(node->op)) {
- continue;
- }
- int * node_backend_id = &tensor_backend_id(node);
- if (*node_backend_id == -1) {
- // unassigned node: find the backend with the most supported inputs
- int n_supported_best = -1;
- for (int b = 0; b < sched->n_backends; b++) {
- if (ggml_backend_supports_op(sched->backends[b], node)) {
- int n_supported = 0;
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
- if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) {
- n_supported++;
- }
- }
- if (n_supported > n_supported_best) {
- n_supported_best = n_supported;
- *node_backend_id = b;
- SET_CAUSE(node, "3.best");
- }
- }
- }
- } else {
- // assigned node: upgrade to higher prio backend if possible
- for (int b = 0; b < *node_backend_id; b++) {
- if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) {
- bool supported = true;
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
- if (!ggml_backend_sched_buffer_supported(sched, src, b)) {
- supported = false;
- break;
- }
- }
- if (supported) {
- *node_backend_id = b;
- SET_CAUSE(node, "3.upg");
- break;
- }
- }
- }
- }
- }
-
- // pass 4: assign backends to remaining src from dst and view_src
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- int * cur_backend_id = &tensor_backend_id(node);
- if (node->view_src != NULL && *cur_backend_id == -1) {
- *cur_backend_id = tensor_backend_id(node->view_src);
- SET_CAUSE(node, "4.vsrc");
- }
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
- int * src_backend_id = &tensor_backend_id(src);
- if (*src_backend_id == -1) {
- if (src->view_src != NULL) {
- // views are always on the same backend as the source
- *src_backend_id = tensor_backend_id(src->view_src);
- SET_CAUSE(src, "4.vsrc");
- } else {
- *src_backend_id = *cur_backend_id;
- SET_CAUSE(src, "4.cur");
- }
- }
- }
- }
-
- // pass 5: split graph, find tensors that need to be copied
- {
- int i_split = 0;
- struct ggml_backend_sched_split * split = &sched->splits[0];
- // find the backend of the first split, skipping view ops
- int i = 0;
- for (; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- if (!ggml_is_view_op(node->op)) {
- split->backend_id = tensor_backend_id(node);
- break;
- }
- }
- split->i_start = 0;
- split->n_inputs = 0;
- int cur_backend_id = split->backend_id;
- for (; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
-
- if (ggml_is_view_op(node->op)) {
- continue;
- }
-
- const int node_backend_id = tensor_backend_id(node);
-
- assert(node_backend_id != -1); // all nodes should be assigned by now
-
- // check if we should start a new split based on the sources of the current node
- bool need_new_split = false;
- if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
- // check if a weight is on a different backend
- // by starting a new split, the memory of the previously offloaded weights can be reused
- if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
- int src_backend_id = tensor_backend_id(src);
- if (src_backend_id != cur_backend_id) {
- need_new_split = true;
- break;
- }
- }
- // check if the split has too many inputs
- // FIXME: count the number of inputs instead of only checking when full
- if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
- const size_t id = hash_id(src);
- int src_backend_id = sched->hv_tensor_backend_ids[id];
- bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
- if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) {
- //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
- need_new_split = true;
- break;
- }
- }
- }
- }
-
- if (node_backend_id != cur_backend_id || need_new_split) {
- split->i_end = i;
- i_split++;
- if (i_split >= sched->splits_capacity) {
- sched->splits_capacity *= 2;
- sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
- GGML_ASSERT(sched->splits != NULL);
- }
- split = &sched->splits[i_split];
- split->backend_id = node_backend_id;
- split->i_start = i;
- split->n_inputs = 0;
- cur_backend_id = node_backend_id;
- }
-
- // find inputs that are not on the same backend
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * src = node->src[j];
- if (src == NULL) {
- continue;
- }
-
- size_t src_id = hash_id(src);
- const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
- assert(src_backend_id != -1); // all inputs should be assigned by now
-
- if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
- if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
- ggml_backend_t backend = sched->backends[src_backend_id];
- for (int c = 0; c < sched->n_copies; c++) {
- struct ggml_tensor * tensor_copy;
- if (c == sched->cur_copy) {
- tensor_copy = src; // use the original tensor as the current copy
- } else {
- tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
- ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
- }
- if (sched->n_copies > 1) {
- ggml_set_input(tensor_copy);
- ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
- }
- tensor_id_copy(src_id, src_backend_id, c) = tensor_copy;
- SET_CAUSE(tensor_copy, "4.cpy");
- }
- int n_graph_inputs = sched->n_graph_inputs++;
- GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
- sched->graph_inputs[n_graph_inputs] = src;
- }
- }
-
- if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
- // create a copy of the input in the split's backend
- if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {
- ggml_backend_t backend = sched->backends[cur_backend_id];
- for (int c = 0; c < sched->n_copies; c++) {
- struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
- ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
- if (sched->n_copies > 1) {
- ggml_set_input(tensor_copy);
- ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
- }
- tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy;
- SET_CAUSE(tensor_copy, "4.cpy");
- }
- int n_inputs = split->n_inputs++;
- GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
- split->inputs[n_inputs] = src;
- }
- node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);
- }
- }
- }
- split->i_end = graph->n_nodes;
- sched->n_splits = i_split + 1;
- }
-
- if (sched->debug) {
- ggml_backend_sched_print_assignments(sched, graph);
- }
-
- // swap node_backend_ids and leaf _backend_ids with prevs
- {
- int * tmp = sched->node_backend_ids;
- sched->node_backend_ids = sched->prev_node_backend_ids;
- sched->prev_node_backend_ids = tmp;
-
- tmp = sched->leaf_backend_ids;
- sched->leaf_backend_ids = sched->prev_leaf_backend_ids;
- sched->prev_leaf_backend_ids = tmp;
- }
-
- int graph_size = MAX(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;
- if (sched->graph.size < graph_size) {
- sched->graph.size = graph_size;
- sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
- sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));
- GGML_ASSERT(sched->graph.nodes != NULL);
- GGML_ASSERT(sched->graph.leafs != NULL);
- }
- sched->graph.n_nodes = 0;
- sched->graph.n_leafs = 0;
-
- struct ggml_cgraph * graph_copy = &sched->graph;
-
- for (int i = 0; i < sched->n_splits; i++) {
- struct ggml_backend_sched_split * split = &sched->splits[i];
- split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
-
- // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
- for (int j = 0; j < split->n_inputs; j++) {
- assert(graph_copy->size > (graph_copy->n_nodes + 1));
-
- struct ggml_tensor * input = split->inputs[j];
- const size_t input_id = hash_id(input);
- struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy);
-
- // add a dependency to the input source so that it is not freed before the copy is done
- struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
- input_dep->src[0] = input;
- sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id];
- graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
-
- // add a dependency to the input copy so that it is allocated at the start of the split
- sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id;
- graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
- }
-
- for (int j = split->i_start; j < split->i_end; j++) {
- assert(graph_copy->size > graph_copy->n_nodes);
- sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);
- graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
- }
- }
-
- if (sched->n_copies > 1) {
- // add input copies as leafs so that they are allocated first
- for (int i = 0; i < sched->n_graph_inputs; i++) {
- struct ggml_tensor * input = sched->graph_inputs[i];
- size_t id = hash_id(input);
- int backend_id = tensor_backend_id(input);
- for (int c = 0; c < sched->n_copies; c++) {
- struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
- sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
- assert(graph_copy->size > graph_copy->n_leafs);
- graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
- }
- }
-
- for (int i = 0; i < sched->n_splits; i++) {
- struct ggml_backend_sched_split * split = &sched->splits[i];
- int backend_id = split->backend_id;
- for (int j = 0; j < split->n_inputs; j++) {
- struct ggml_tensor * input = split->inputs[j];
- size_t id = hash_id(input);
- for (int c = 0; c < sched->n_copies; c++) {
- struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
- sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
- assert(graph_copy->size > graph_copy->n_leafs);
- graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
- }
- }
- }
- }
-
- // add leafs from the original graph
- for (int i = 0; i < graph->n_leafs; i++) {
- struct ggml_tensor * leaf = graph->leafs[i];
- sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
- assert(graph_copy->size > graph_copy->n_leafs);
- graph_copy->leafs[graph_copy->n_leafs++] = leaf;
- }
-}
-
-static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
- bool backend_ids_changed = false;
- for (int i = 0; i < sched->graph.n_nodes; i++) {
- if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
- sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
- backend_ids_changed = true;
- break;
- }
- }
- if (!backend_ids_changed) {
- for (int i = 0; i < sched->graph.n_leafs; i++) {
- if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
- sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
- backend_ids_changed = true;
- break;
- }
- }
- }
-
- // allocate graph
- if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
- // the re-allocation may cause the split inputs to be moved to a different address
- ggml_backend_sched_synchronize(sched);
-#ifndef NDEBUG
- fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
-#endif
- ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
- if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
- fprintf(stderr, "%s: failed to allocate graph\n", __func__);
- return false;
- }
- }
-
- return true;
-}
-
-static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
- struct ggml_backend_sched_split * splits = sched->splits;
-
- for (int i = 0; i < sched->n_splits; i++) {
- struct ggml_backend_sched_split * split = &splits[i];
- int split_backend_id = split->backend_id;
- ggml_backend_t split_backend = sched->backends[split_backend_id];
-
- // copy the input tensors to the split backend
- for (int j = 0; j < split->n_inputs; j++) {
- ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
- struct ggml_tensor * input = split->inputs[j];
- struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
-
- if (input->flags & GGML_TENSOR_FLAG_INPUT) {
- // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
- if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
- ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
- } else {
- ggml_backend_synchronize(split_backend);
- }
- ggml_backend_tensor_copy(input, input_cpy);
- } else {
- // wait for the split backend to finish using the input before overwriting it
- if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
- ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
- } else {
- ggml_backend_synchronize(split_backend);
- }
- // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
- // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
- if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
- ggml_backend_synchronize(input_backend);
- if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
- ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
- } else {
- ggml_backend_synchronize(split_backend);
- }
- ggml_backend_tensor_copy(input, input_cpy);
- }
- }
- }
-
- if (!sched->callback_eval) {
- enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
- if (ec != GGML_STATUS_SUCCESS) {
- return ec;
- }
- } else {
- // similar to ggml_backend_compare_graph_backend
- for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
- struct ggml_tensor * t = split->graph.nodes[j0];
-
- // check if the user needs data from this node
- bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
-
- int j1 = j0;
-
- // determine the range [j0, j1] of nodes that can be computed together
- while (!need && j1 < split->graph.n_nodes - 1) {
- t = split->graph.nodes[++j1];
- need = sched->callback_eval(t, true, sched->callback_eval_user_data);
- }
-
- struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
-
- enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
- if (ec != GGML_STATUS_SUCCESS) {
- return ec;
- }
-
- // TODO: pass backend to the callback, then the user can decide if they want to synchronize
- ggml_backend_synchronize(split_backend);
-
- if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
- break;
- }
-
- j0 = j1;
- }
- }
-
- // record the event of this copy
- if (split->n_inputs > 0) {
- if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
- ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]);
- }
- }
- }
-
- sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
-
- return GGML_STATUS_SUCCESS;
-}
-
-ggml_backend_sched_t ggml_backend_sched_new(
- ggml_backend_t * backends,
- ggml_backend_buffer_type_t * bufts,
- int n_backends,
- size_t graph_size,
- bool parallel) {
- GGML_ASSERT(n_backends > 0);
- GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
- GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
-
- struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched));
-
- sched->debug = getenv("GGML_SCHED_DEBUG") != NULL;
- sched->n_backends = n_backends;
- sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
-
- // initialize hash table
- // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)
- sched->hash_set = ggml_hash_set_new(graph_size);
- sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
- sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
-
- const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
- const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
- sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
- sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
- sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
- sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
-
- sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
- sched->context_buffer = malloc(sched->context_buffer_size);
-
- const int initial_splits_capacity = 16;
- sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0]));
- sched->splits_capacity = initial_splits_capacity;
-
- for (int b = 0; b < n_backends; b++) {
- sched->backends[b] = backends[b];
- sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
- GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
- if (sched->n_copies > 1) {
- for (int c = 0; c < sched->n_copies; c++) {
- sched->events[b][c] = ggml_backend_event_new(backends[b]);
- }
- }
- }
-
- sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
-
- ggml_backend_sched_reset(sched);
-
- return sched;
-}
-
-void ggml_backend_sched_free(ggml_backend_sched_t sched) {
- if (sched == NULL) {
- return;
- }
- for (int b = 0; b < sched->n_backends; b++) {
- for (int c = 0; c < sched->n_copies; c++) {
- ggml_backend_event_free(sched->events[b][c]);
- }
- }
- ggml_gallocr_free(sched->galloc);
- ggml_free(sched->ctx);
- ggml_hash_set_free(&sched->hash_set);
- free(sched->splits);
- free(sched->hv_tensor_backend_ids);
- free(sched->hv_tensor_copies);
- free(sched->node_backend_ids);
- free(sched->leaf_backend_ids);
- free(sched->prev_node_backend_ids);
- free(sched->prev_leaf_backend_ids);
- free(sched->context_buffer);
- free(sched->graph.nodes);
- free(sched->graph.leafs);
- free(sched);
-}
-
-void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
- // reset state for the next run
- if (!sched->is_reset) {
- ggml_hash_set_reset(&sched->hash_set);
- memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
- memset(sched->hv_tensor_copies, 0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
- sched->is_reset = true;
- }
- sched->is_alloc = false;
-}
-
-bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
- GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
-
- ggml_backend_sched_split_graph(sched, measure_graph);
-
- if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
- return false;
- }
-
- ggml_backend_sched_reset(sched);
- ggml_backend_sched_synchronize(sched);
-
- return true;
-}
-
-bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
- GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
-
- ggml_backend_sched_split_graph(sched, graph);
-
-
- if (!ggml_backend_sched_alloc_splits(sched)) {
- return false;
- }
-
- sched->is_alloc = true;
-
- return true;
-}
-
-enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
- enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
- ggml_backend_sched_synchronize(sched);
- return err;
-}
-
-enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
- if (!sched->is_reset && !sched->is_alloc) {
- ggml_backend_sched_reset(sched);
- }
-
- if (!sched->is_alloc) {
- if (!ggml_backend_sched_alloc_graph(sched, graph)) {
- return GGML_STATUS_ALLOC_FAILED;
- }
- }
-
- return ggml_backend_sched_compute_splits(sched);
-}
-
-void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
- for (int i = 0; i < sched->n_backends; i++) {
- ggml_backend_synchronize(sched->backends[i]);
- }
-}
-
-void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
- sched->callback_eval = callback;
- sched->callback_eval_user_data = user_data;
-}
-
-int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
- return sched->n_splits;
-}
-
-int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
- return sched->n_copies;
-}
-
-int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
- return sched->n_backends;
-}
-
-ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
- GGML_ASSERT(i >= 0 && i < sched->n_backends);
- return sched->backends[i];
-}
-
-size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
- int backend_index = ggml_backend_sched_backend_id(sched, backend);
- GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
-
- return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
-}
-
-void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
- int backend_index = ggml_backend_sched_backend_id(sched, backend);
- GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
- tensor_backend_id(node) = backend_index;
- SET_CAUSE(node, "usr");
- sched->is_reset = false;
-}
-
-ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
- int backend_index = tensor_backend_id(node);
- if (backend_index == -1) {
- return NULL;
- }
- return sched->backends[backend_index];
-}
-
-// utils
-
-void ggml_backend_view_init(struct ggml_tensor * tensor) {
- GGML_ASSERT(tensor->buffer == NULL);
- GGML_ASSERT(tensor->view_src != NULL);
- GGML_ASSERT(tensor->view_src->buffer != NULL);
- GGML_ASSERT(tensor->view_src->data != NULL);
-
- tensor->buffer = tensor->view_src->buffer;
- tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
- ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
-}
-
-void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
- GGML_ASSERT(tensor->buffer == NULL);
- GGML_ASSERT(tensor->data == NULL);
- GGML_ASSERT(tensor->view_src == NULL);
- GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
- GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
- (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
-
- tensor->buffer = buffer;
- tensor->data = addr;
- ggml_backend_buffer_init_tensor(buffer, tensor);
-}
-
-static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
- struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
-
- GGML_ASSERT(src != NULL);
- GGML_ASSERT(src->data && "graph must be allocated");
-
- size_t id = ggml_hash_insert(&hash_set, src);
- if (id == GGML_HASHSET_ALREADY_EXISTS) {
- return node_copies[ggml_hash_find(&hash_set, src)];
- }
-
- struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
- if (src->view_src != NULL) {
- dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
- dst->view_offs = src->view_offs;
- }
- dst->op = src->op;
- memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
- ggml_set_name(dst, src->name);
-
- // copy src
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- struct ggml_tensor * s = src->src[i];
- if (s == NULL) {
- continue;
- }
- dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
- }
-
- node_copies[id] = dst;
- return dst;
-}
-
-static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
- size_t id = ggml_hash_find(hash_set, src);
- if (node_init[id]) {
- return;
- }
- node_init[id] = true;
-
- struct ggml_tensor * dst = node_copies[id];
- if (dst->view_src != NULL) {
- graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
- ggml_backend_view_init(dst);
- }
- else {
- ggml_backend_tensor_copy(src, dst);
- }
-
- // init src
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- struct ggml_tensor * s = src->src[i];
- if (s == NULL) {
- continue;
- }
- graph_copy_init_tensor(hash_set, node_copies, node_init, s);
- }
-}
-
-struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
- struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
- struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
- bool * node_init = calloc(hash_set.size, sizeof(node_init[0]));
-
- struct ggml_init_params params = {
- /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
- /* .mem_buffer = */ NULL,
- /* .no_alloc = */ true
- };
-
- struct ggml_context * ctx_allocated = ggml_init(params);
- struct ggml_context * ctx_unallocated = ggml_init(params);
-
- if (ctx_allocated == NULL || ctx_unallocated == NULL) {
- fprintf(stderr, "failed to allocate context for graph copy\n");
- ggml_hash_set_free(&hash_set);
- free(node_copies);
- free(node_init);
- ggml_free(ctx_allocated);
- ggml_free(ctx_unallocated);
- return (struct ggml_backend_graph_copy) {
- /* .buffer = */ NULL,
- /* .ctx_allocated = */ NULL,
- /* .ctx_unallocated = */ NULL,
- /* .graph = */ NULL,
- };
- }
-
- // dup nodes
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
- }
-
- // allocate nodes
- ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
- if (buffer == NULL) {
- fprintf(stderr, "failed to allocate buffer for graph copy\n");
- ggml_hash_set_free(&hash_set);
- free(node_copies);
- free(node_init);
- ggml_free(ctx_allocated);
- ggml_free(ctx_unallocated);
- return (struct ggml_backend_graph_copy) {
- /* .buffer = */ NULL,
- /* .ctx_allocated = */ NULL,
- /* .ctx_unallocated = */ NULL,
- /* .graph = */ NULL,
- };
- }
-
- //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
-
- // copy data and init views
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- graph_copy_init_tensor(&hash_set, node_copies, node_init, node);
- }
-
- // build graph copy
- struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
- for (int i = 0; i < graph->n_nodes; i++) {
- struct ggml_tensor * node = graph->nodes[i];
- struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)];
- graph_copy->nodes[i] = node_copy;
- }
- graph_copy->n_nodes = graph->n_nodes;
-
- ggml_hash_set_free(&hash_set);
- free(node_copies);
- free(node_init);
-
- return (struct ggml_backend_graph_copy) {
- /* .buffer = */ buffer,
- /* .ctx_allocated = */ ctx_allocated,
- /* .ctx_unallocated = */ ctx_unallocated,
- /* .graph = */ graph_copy,
- };
-}
-
-void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
- ggml_backend_buffer_free(copy.buffer);
- ggml_free(copy.ctx_allocated);
- ggml_free(copy.ctx_unallocated);
-}
-
-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
- struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
- if (copy.buffer == NULL) {
- return false;
- }
-
- struct ggml_cgraph * g1 = graph;
- struct ggml_cgraph * g2 = copy.graph;
-
- assert(g1->n_nodes == g2->n_nodes);
-
- for (int i = 0; i < g1->n_nodes; i++) {
- //printf("eval %d/%d\n", i, g1->n_nodes);
- struct ggml_tensor * t1 = g1->nodes[i];
- struct ggml_tensor * t2 = g2->nodes[i];
-
- assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
-
- struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
- struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
-
- ggml_backend_graph_compute(backend1, &g1v);
- ggml_backend_graph_compute(backend2, &g2v);
-
- if (ggml_is_view_op(t1->op)) {
- continue;
- }
-
- // compare results, calculate rms etc
- if (!callback(i, t1, t2, user_data)) {
- break;
- }
- }
-
- ggml_backend_graph_copy_free(copy);
-
- return true;
-}