}
static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    const vk_device& device = ggml_vk_get_device(ctx->device);
+
+    // reject any tensors larger than the max buffer size
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
+            return false;
+        }
+    }
+    if (ggml_nbytes(op) > device->max_buffer_size) {
+        return false;
+    }
+
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
        case GGML_OP_MUL_MAT_ID:
            {
                ggml_type src0_type = op->src[0]->type;
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
                if (op->op == GGML_OP_MUL_MAT_ID) {
                    if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                        // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
                        return false;
                    }
        case GGML_OP_FLASH_ATTN_EXT:
            {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                bool coopmat2 = device->coopmat2;
                uint32_t HSK = op->src[1]->ne[0];
                uint32_t HSV = op->src[2]->ne[0];
                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                    return false;
                }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                // pipeline_argsort_large_f32 requires vulkan memory model.
                if (device->vulkan_memory_model) {
                    return true;
                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                    return false;
                }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                // We could potentially support larger, using argsort to sort the
                // whole thing. Not clear if this is needed.
                uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
                return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_CUMSUM:
            {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
                    return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
                }
            }
        case GGML_OP_SOLVE_TRI:
            {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
                    return false;
                }
                return false;
            }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                const uint32_t SPLIT_H = 16;
                size_t stateC_size = SPLIT_H * d_state * sizeof(float);