int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
+ bool is_allocated = false;
// for debugging
bool debug_graph = false;
};
static void warmup(clip_ctx & ctx_clip) {
+ // create a fake batch
+ // Builds a minimal single-entry batch sized from the model hyperparameters
+ // (square warmup_image_size for vision models; warmup_audio_size x
+ // n_mel_bins otherwise) and forwards it to the two-argument warmup
+ // overload, which performs the actual buffer reservation.
+ const auto & hparams = ctx_clip.model.hparams;
+ clip_image_f32_batch batch;
+ clip_image_f32_ptr img(clip_image_f32_init());
+ if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
+ img->nx = hparams.warmup_image_size;
+ img->ny = hparams.warmup_image_size;
+ LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
+ } else {
+ // non-vision (audio): nx carries the warmup audio size, ny the mel-bin count
+ img->nx = hparams.warmup_audio_size;
+ img->ny = hparams.n_mel_bins;
+ LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
+ }
+ batch.entries.push_back(std::move(img));
+ warmup(ctx_clip, batch);
+ }
+
+ // Reserves compute buffers by building a graph for the given batch, probing
+ // flash-attention support in the process. When flash_attn_type is AUTO it
+ // tries ENABLED first and falls back to DISABLED (re-reserving) if the
+ // backend rejects the fattn op. Marks ctx_clip.is_allocated on completion.
+ static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
support_info_graph info;
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
- info = alloc_compute_meta(ctx_clip);
+ info = alloc_compute_meta(ctx_clip, batch);
if (!info.fattn && info.fattn_op) {
+ // flash attention unsupported for this op: warn, disable, and
+ // re-run the reservation with the fallback configuration
auto op = info.fattn_op;
LOG_WRN("%s: *****************************************************************\n", __func__);
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
- alloc_compute_meta(ctx_clip);
+ alloc_compute_meta(ctx_clip, batch);
}
} else {
- info = alloc_compute_meta(ctx_clip);
+ info = alloc_compute_meta(ctx_clip, batch);
if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+ // user forced flash attention on; respect the choice but warn
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
}
+ ctx_clip.is_allocated = true; // mark buffers as allocated
+
LOG_INF("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
}
}
+ // Sizes the graph-metadata buffer and reserves backend scheduler buffers for
+ // a graph built from the caller-supplied batch; the fake warmup batch is now
+ // constructed by the caller instead of inside this function. Returns op
+ // support info (e.g. flash-attention support) consumed by warmup().
- static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
- const auto & hparams = ctx_clip.model.hparams;
+ static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
+ // reserve metadata space for up to max_nodes tensors plus the graph itself
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
- // create a fake batch
- clip_image_f32_batch batch;
- clip_image_f32_ptr img(clip_image_f32_init());
- if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
- img->nx = hparams.warmup_image_size;
- img->ny = hparams.warmup_image_size;
- LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
- } else {
- img->nx = hparams.warmup_audio_size;
- img->ny = hparams.n_mel_bins;
- LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
- }
- batch.entries.push_back(std::move(img));
-
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
+ // NOTE(review): `return false` from a function declared to return
+ // support_info_graph looks like two patch hunks fused together — confirm
+ // against the full file that this return belongs to a different function
+ // (it reads like a batch-size guard from an encode path).
return false; // only support batch size of 1
}
+ // if buffers are not allocated, we need to do a warmup run to allocate them
+ if (!ctx->is_allocated) {
+ clip_model_loader::warmup(*ctx, *imgs_c_ptr);
+ }
+
// build the inference graph
ctx->debug_print_tensors.clear();
ggml_backend_sched_reset(ctx->sched.get());