vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
vk_pipeline pipeline_mul_f32;
vk_pipeline pipeline_div_f32;
- vk_pipeline pipeline_add_f32;
+ vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
+ vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_clamp_f32;
+ vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
vk_pipeline pipeline_norm_f32;
+ vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_gelu_f32;
+ vk_pipeline pipeline_gelu_quick_f32;
vk_pipeline pipeline_silu_f32;
vk_pipeline pipeline_relu_f32;
+ vk_pipeline pipeline_leaky_relu_f32;
+ vk_pipeline pipeline_tanh_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_argsort_f32;
vk_pipeline pipeline_sum_rows_f32;
+ vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
+ vk_pipeline pipeline_timestep_embedding_f32;
std::vector<vk_pipeline_ref> pipelines;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
uint32_t d_offset;
- float param1; float param2;
+ float param1; float param2; int32_t param3;
};
struct vk_op_diag_mask_push_constants {
uint32_t ncols;
uint32_t rows_per_channel;
int32_t n_past;
};
+struct vk_op_im2col_push_constants {
+ uint32_t batch_offset; uint32_t offset_delta;
+ uint32_t IC;
+ uint32_t IW; uint32_t IH;
+ uint32_t OW; uint32_t OH;
+ uint32_t KW; uint32_t KH;
+ uint32_t pelements;
+ uint32_t CHW;
+ int32_t s0; int32_t s1;
+ int32_t p0; int32_t p1;
+ int32_t d0; int32_t d1;
+};
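+
+// Field notes (a sketch, not normative): IC/IW/IH describe the input image,
+// OW/OH the output, KW/KH the kernel, and s/p/d the stride, padding and
+// dilation. With the usual convolution arithmetic,
+// OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1; e.g. IW=8, KW=3, s0=1, p0=1,
+// d0=1 gives OW=8. pelements = OW * KW * KH bounds the x dispatch dimension
+// (see the GGML_OP_IM2COL case in ggml_vk_op_f32) and CHW = IC * KH * KW is
+// the element stride between output rows.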
+
+struct vk_op_timestep_embedding_push_constants {
+ uint32_t nb1;
+ uint32_t dim;
+ uint32_t max_period;
+};
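+
+// A sketch of what the shader is expected to compute, assumed to mirror the
+// CPU reference: for each timestep t and j in [0, dim/2), with
+// freq = exp(-ln(max_period) * j / (dim/2)), the output row at stride nb1
+// stores cos(t * freq) at [j] and sin(t * freq) at [j + dim/2], plus a zero
+// pad when dim is odd.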
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
void * dst;
const void * src;
size_t n;
};
-struct vk_context {
- size_t idx;
+struct vk_op_upscale_push_constants {
+ uint32_t ne; uint32_t d_offset;
+ uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+ float sf0; float sf1; float sf2; float sf3;
+};
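+
+// Assumed shader behavior (nearest-neighbor upscale): a dst element at index
+// (i0, i1, i2, i3) reads the src element at ((uint32_t)(i0 / sf0), ...),
+// where sf0..sf3 are the dst/src size ratios set up in ggml_vk_upscale.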
+struct vk_context_struct {
vk_submission * s;
std::vector<vk_sequence> seqs;
- ggml_tensor * exit_tensor;
+ int exit_tensor_idx;
std::vector<vk_staging_memcpy> in_memcpys;
std::vector<vk_staging_memcpy> out_memcpys;
vk_queue * q;
};
+typedef std::shared_ptr<vk_context_struct> vk_context;
+typedef std::weak_ptr<vk_context_struct> vk_context_ref;
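+
+// Contexts are now reference-counted: the owner keeps a vk_context
+// (shared_ptr) alive for the duration of a submission, while cached
+// references such as compute_ctx/transfer_ctx below are vk_context_ref
+// (weak_ptr) and must be lock()ed before use. This replaces the old
+// fixed-index slots in gc.contexts and lets temporary contexts be freed
+// automatically instead of via explicit delete.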
struct ggml_tensor_extra_gpu {
- size_t ctx_idx;
-
vk_buffer_ref buffer_gpu;
uint64_t offset;
void reset() {
- ctx_idx = 0;
buffer_gpu.reset();
offset = 0;
}
vk_buffer buffer_pool[MAX_VK_BUFFERS];
- vk_context * compute_ctx;
- vk_context * transfer_ctx;
+ vk_context_ref compute_ctx;
+ vk_context_ref transfer_ctx;
+
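+ // assumed: one (weak) entry per graph node, so the backend can later look
+ // up the context that recorded a given tensor without extending its lifetime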
+ std::vector<vk_context_ref> tensor_ctxs;
};
#ifdef GGML_VULKAN_MEMORY_DEBUG
static size_t vk_skip_checks;
static size_t vk_output_tensor;
-static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
+static void ggml_vk_check_results_0(ggml_tensor * tensor);
+static void ggml_vk_check_results_1(ggml_tensor * tensor);
#endif
-typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
return s;
}
-static void ggml_vk_submit(vk_context * ctx, vk::Fence fence) {
- VK_LOG_DEBUG("ggml_vk_submit(" << ctx->seqs.size() << ", " << fence << ")");
+static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
if (ctx->seqs.empty()) {
return;
}
+ VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
std::vector<std::vector<uint64_t>> tl_wait_vals;
std::vector<std::vector<uint64_t>> tl_signal_vals;
q.stage_flags = stage_flags;
}
-static vk_context * ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
- VK_LOG_DEBUG("ggml_vk_create_context()");
- ctx->gc.contexts.emplace_back();
- vk_context * result = &ctx->gc.contexts[ctx->gc.contexts.size() - 1];
- memset((void *) result, 0, sizeof(vk_context));
- result->idx = ctx->gc.contexts.size() - 1;
+static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
+ vk_context result = std::make_shared<vk_context_struct>();
+ VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
+ ctx->gc.contexts.emplace_back(result);
result->q = &q;
return result;
}
-static vk_context * ggml_vk_create_temporary_context(vk_queue& q) {
- VK_LOG_DEBUG("ggml_vk_create_temporary_context()");
- vk_context * result = new vk_context;
- memset((void *) result, 0, sizeof(vk_context));
- result->idx = 0;
+static vk_context ggml_vk_create_temporary_context(vk_queue& q) {
+ vk_context result = std::make_shared<vk_context_struct>();
+ VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
result->q = &q;
return result;
}
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+ if (size > device->max_memory_allocation_size) {
+ throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
+ }
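+ // max_memory_allocation_size is assumed to be the device limit queried at
+ // init (e.g. VkPhysicalDeviceMaintenance3Properties::maxMemoryAllocationSize);
+ // checking it up front yields a catchable vk::OutOfDeviceMemoryError instead
+ // of relying on the driver to reject an oversized allocation.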
+
std::lock_guard<std::mutex> guard(device->mutex);
vk_buffer buf = std::make_shared<vk_buffer_struct>();
return { buf, 0, VK_WHOLE_SIZE };
}
-static void ggml_vk_sync_buffers(vk_context * ctx) {
+static void ggml_vk_sync_buffers(vk_context& ctx) {
VK_LOG_DEBUG("ggml_vk_sync_buffers()");
const std::vector<vk::MemoryBarrier> mem_barriers{ { { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite }, { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite } } };
);
}
-static void ggml_vk_wait_events(vk_context * ctx, std::vector<vk::Event>&& events) {
+static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
VK_LOG_DEBUG("ggml_vk_wait_events()");
if (events.empty()) {
return;
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
}
static vk_device ggml_vk_get_device(size_t idx) {
ctx->staging_size = 0;
ctx->staging_offset = 0;
- ctx->compute_ctx = nullptr;
- ctx->transfer_ctx = nullptr;
-
#ifdef GGML_VULKAN_CHECK_RESULTS
const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
}
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
- VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline()");
+ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return ctx->device->pipeline_matmul_f32;
}
return ctx->device->pipeline_matmul_f16;
}
- GGML_ASSERT(src1_type == GGML_TYPE_F32);
+ if (src1_type != GGML_TYPE_F32) {
+ return nullptr;
+ }
switch (src0_type) {
case GGML_TYPE_Q4_0:
return s;
}
-static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
+static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
s.signal_semaphores = std::move(signal_semaphores);
}
-static void ggml_vk_ctx_end(vk_context * ctx) {
+static void ggml_vk_ctx_end(vk_context& ctx) {
VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
if (ctx->s == nullptr) {
return;
ctx->s = nullptr;
}
-static void ggml_vk_ctx_begin(vk_device& device, vk_context * subctx) {
+static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
if (subctx->s != nullptr) {
ggml_vk_ctx_end(subctx);
}
}
-static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
+static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
GGML_ASSERT(!ggml_is_contiguous(tensor));
// Buffer is already mapped
}
}
-static void ggml_vk_buffer_write_2d_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
// Buffer is already mapped
if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
}
}
-static void ggml_vk_buffer_write_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging);
}
memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
}
} else {
- vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
ggml_vk_ctx_begin(dst->device, subctx);
ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dst->device->fence);
VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
dst->device->device.resetFences({ dst->device->fence });
-
- delete subctx;
}
}
ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
}
-static void ggml_vk_buffer_read_2d_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
GGML_ASSERT(width > 0);
GGML_ASSERT(height > 0);
GGML_ASSERT(src != nullptr);
+ // TODO: staging_offset is not used
+
// Check if dst is pinned memory
vk_buffer buf = nullptr;
size_t buf_offset;
deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
}
-static void ggml_vk_buffer_read_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging);
}
static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
- VK_LOG_DEBUG("ggml_vk_buffer_read(" << offset << ", " << size << ")");
+ VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
if (src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
memcpy(dst, (uint8_t *) src->ptr + offset, size);
} else {
- vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
ggml_vk_ctx_begin(src->device, subctx);
ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true);
ggml_vk_ctx_end(subctx);
for (auto& cpy : subctx->out_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
-
- delete subctx;
}
}
-static void ggml_vk_buffer_copy_async(vk_context * ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
+static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
// Make sure both buffers are on the same device
GGML_ASSERT(src->device == dst->device);
if (src->device == dst->device) {
VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
// Copy within the device
- vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
ggml_vk_ctx_begin(src->device, subctx);
ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, src->device->fence);
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
src->device->device.resetFences({ src->device->fence });
-
- delete subctx;
} else {
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
- vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
ggml_vk_ctx_begin(dst->device, subctx);
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dst->device->fence);
VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
dst->device->device.resetFences({ dst->device->fence });
-
- delete subctx;
}
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
}
static void ggml_vk_matmul(
- ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
}
static void ggml_vk_matmul_id(
- ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
GGML_ABORT("fatal error");
}
-static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
+static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
const int tensor_type_size = ggml_type_size(tensor->type);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
}
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
); // NOLINT
}
-static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
}
-static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
}
-static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
}
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst);
}
}
-static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
); // NOLINT
}
-static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
}
-static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst);
}
}
-static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- // guaranteed to be an integer due to the check in ggml_can_repeat
+static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_op_repeat(" << src0 << ", " << src1 << ", " << dst << ")");
const uint64_t ne0 = dst->ne[0];
const uint64_t ne1 = dst->ne[1];
const uint64_t ne2 = dst->ne[2];
const uint64_t nb02 = src0->nb[2];
const uint64_t nb03 = src0->nb[3];
+ // guaranteed to be an integer due to the check in ggml_can_repeat
const uint64_t nr0 = ne0/ne00;
const uint64_t nr1 = ne1/ne01;
const uint64_t nr2 = ne2/ne02;
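// Each copy below reads one contiguous block of ne00*nb0 bytes from its
// (k3, k2, k1) position in src and writes it to every repeated
// (i3, i2, i1, i0) tile of dst, so the repeat-strided offsets apply to the
// destination while the source offset only carries the k-indices.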
for (uint64_t k1 = 0; k1 < ne01; k1++) {
for (uint64_t i0 = 0; i0 < nr0; i0++) {
copies.push_back({
- src_offset + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
- dst_offset + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
+ src_offset + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
+ dst_offset + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
ne00*nb0,
});
}
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
switch (op) {
- case GGML_OP_ADD:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_add_f32;
- }
- return nullptr;
case GGML_OP_GET_ROWS:
GGML_ASSERT(src1->type == GGML_TYPE_I32);
if (dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
+ case GGML_OP_ADD:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_add_f32;
+ }
+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_add_f16_f32_f16;
+ }
+ return nullptr;
case GGML_OP_MUL:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_mul_f32;
return ctx->device->pipeline_div_f32;
}
return nullptr;
+ case GGML_OP_CONCAT:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_concat_f32;
+ }
+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_concat_f16;
+ }
+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+ return ctx->device->pipeline_concat_i32;
+ }
+ return nullptr;
+ case GGML_OP_UPSCALE:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_upscale_f32;
+ }
+ return nullptr;
case GGML_OP_SCALE:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_scale_f32;
return ctx->device->pipeline_clamp_f32;
}
return nullptr;
+ case GGML_OP_PAD:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_pad_f32;
+ }
+ return nullptr;
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
return ctx->device->pipeline_norm_f32;
}
return nullptr;
+ case GGML_OP_GROUP_NORM:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_group_norm_f32;
+ }
+ return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rms_norm_f32;
return ctx->device->pipeline_gelu_f32;
}
break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_gelu_quick_f32;
+ }
+ break;
case GGML_UNARY_OP_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_relu_f32;
}
break;
+ case GGML_UNARY_OP_TANH:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_tanh_f32;
+ }
+ break;
default:
break;
}
return ctx->device->pipeline_sum_rows_f32;
}
return nullptr;
+ case GGML_OP_IM2COL:
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_im2col_f32;
+ }
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_im2col_f32_f16;
+ }
+ return nullptr;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_timestep_embedding_f32;
+ }
+ return nullptr;
+ case GGML_OP_LEAKY_RELU:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_leaky_relu_f32;
+ }
+ return nullptr;
default:
return nullptr;
}
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
+ case GGML_OP_PAD:
return true;
default:
return false;
}
template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
if (src1 != nullptr) {
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
vk_buffer d_D = extra->buffer_gpu.lock();
// Workaround for tiny tensor inputs on ROPE
- if (use_src1 && y_sz > d_D->size) {
+ if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
y_sz = VK_WHOLE_SIZE;
}
if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
- switch (dst->op) {
+ switch (op) {
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_SOFT_MAX:
case GGML_OP_SUM_ROWS:
- elements = { (uint32_t)ggml_nrows(src0), 1, 1 };
- break;
+ {
+ const uint32_t nr = ggml_nrows(src0);
+ if (nr > 262144) {
+ elements = { 512, 512, CEIL_DIV(nr, 262144) };
+ } else if (nr > 512) {
+ elements = { 512, CEIL_DIV(nr, 512), 1 };
+ } else {
+ elements = { nr, 1, 1 };
+ }
+ } break;
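+ // NOTE: 262144 = 512 * 512. The decomposition above spreads the row count
+ // over up to three dispatch dimensions so that no single workgroup-count
+ // component exceeds Vulkan's guaranteed minimum limit of 65535 per dimension.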
+ case GGML_OP_GROUP_NORM:
+ {
+ const uint32_t num_groups = dst->op_params[0];
+ elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
+ } break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ROPE:
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
break;
+ case GGML_OP_IM2COL:
+ {
+ const bool is_2D = dst->op_params[6] == 1;
+
+ const uint32_t IC = src1->ne[is_2D ? 2 : 1];
+
+ const uint32_t KH = is_2D ? src0->ne[1] : 1;
+ const uint32_t KW = src0->ne[0];
+
+ const uint32_t OH = is_2D ? dst->ne[2] : 1;
+ const uint32_t OW = dst->ne[1];
+
+ const uint32_t batch = src1->ne[3];
+
+ elements = { OW * KW * KH, OH, batch * IC };
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ const uint32_t dim = dst->op_params[0];
+ uint32_t half_ceil = (dim + 1) / 2;
+ elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_DIV:
+ case GGML_OP_MUL:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_PAD:
+ case GGML_OP_CPY:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_UNARY:
+ {
+ const uint32_t ne = ggml_nelements(dst);
+ if (ne > 262144) {
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
+ } else if (ne > 512) {
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
+ } else {
+ elements = { ne, 1, 1 };
+ }
+ } break;
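+ // same 512 x 512 x N decomposition as the row-count case above, applied
+ // to the flat element count of dst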
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
} else {
- subbuf_y = { d_X, 0, d_X->size };
+ subbuf_y = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
- subbuf_z = { d_X, 0, d_X->size };
+ subbuf_z = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else if (op == GGML_OP_IM2COL) {
+ // im2col uses only src1 and dst buffers
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
} else if (use_src2) {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03);
- switch (dst->op) {
+ switch (op) {
case GGML_OP_NORM:
+ case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
elements = { (uint32_t)ne01, 1, 1 };
break;
}
}
-static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {});
}
-static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f,
+ 0.0f, 0.0f, 0,
});
}
-static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f,
+ 0.0f, 0.0f, 0,
});
}
-static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f,
+ 0.0f, 0.0f, 0,
});
}
-static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f,
+ 0.0f, 0.0f, 0,
});
}
-static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ int * op_params = (int *)dst->op_params;
+
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f, op_params[0],
+ });
+}
+
+static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+
+ const float sf0 = (float)dst->ne[0] / src0->ne[0];
+ const float sf1 = (float)dst->ne[1] / src0->ne[1];
+ const float sf2 = (float)dst->ne[2] / src0->ne[2];
+ const float sf3 = (float)dst->ne[3] / src0->ne[3];
+
+ ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
+ (uint32_t)ggml_nelements(dst), 0,
+ (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
+ sf0, sf1, sf2, sf3,
+ });
+}
+
+static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
});
}
-static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
});
}
-static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
});
}
-static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
});
}
-static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
}
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ int * op_params = (int *)dst->op_params;
+
+ uint32_t num_groups = op_params[0];
+ uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+ static const float eps = 1e-6f;
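+ // group_size is the element count of one group: ne0 * ne1 elements per
+ // channel times ceil(ne2 / num_groups) channels; eps is assumed to match
+ // the epsilon hardcoded by the CPU GROUP_NORM implementation.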
+
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
+}
+
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
}
-static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
}
-static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
int32_t * op_params = (int32_t *)dst->op_params;
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
}
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
float scale = op_params[0];
});
}
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const int n_dims = ((int32_t *) dst->op_params)[1];
// const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
});
}
-static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
});
}
-static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f });
}
+static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const int32_t s0 = dst->op_params[0];
+ const int32_t s1 = dst->op_params[1];
+ const int32_t p0 = dst->op_params[2];
+ const int32_t p1 = dst->op_params[3];
+ const int32_t d0 = dst->op_params[4];
+ const int32_t d1 = dst->op_params[5];
+
+ const bool is_2D = dst->op_params[6] == 1;
+
+ const uint32_t IC = src1->ne[is_2D ? 2 : 1];
+ const uint32_t IH = is_2D ? src1->ne[1] : 1;
+ const uint32_t IW = src1->ne[0];
+
+ const uint32_t KH = is_2D ? src0->ne[1] : 1;
+ const uint32_t KW = src0->ne[0];
+
+ const uint32_t OH = is_2D ? dst->ne[2] : 1;
+ const uint32_t OW = dst->ne[1];
+
+ const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+ const uint32_t pelements = OW * KW * KH;
+
+ ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
+ batch_offset, offset_delta,
+ IC, IW, IH, OW, OH, KW, KH,
+ pelements,
+ IC * KH * KW,
+ s0, s1, p0, p1, d0, d1,
+ });
+}
+
+static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const uint32_t dim = dst->op_params[0];
+ const uint32_t max_period = dst->op_params[1];
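+ // dst row stride, converted from bytes to elements of dst->type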
+ const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
+ nb1, dim, max_period,
+ });
+}
+
+static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const float * op_params = (const float *)dst->op_params;
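+ // op_params[0] holds the negative slope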
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+}
+
#ifdef GGML_VULKAN_RUN_TESTS
static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
for (size_t i = 0; i < num_it; i++) {
ggml_vk_ctx_begin(ctx->device, subctx);
ggml_vk_matmul(
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
ggml_vk_ctx_begin(ctx->device, subctx);
const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
ggml_vk_buffer_write(y_buf, 0, y, y_sz);
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
for (size_t i = 0; i < num_it; i++) {
ggml_vk_ctx_begin(ctx->device, subctx);
ggml_vk_matmul(
const bool y_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32 && !y_non_contig;
- bool mmp = (use_src0 && use_src1 && src1_type == GGML_TYPE_F32) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false;
+ bool mmp = (use_src0 && use_src1 && (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID)) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false;
const bool qx_needs_dequant = use_src0 && (!mmp || x_non_contig);
const bool qy_needs_dequant = use_src1 && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
+ case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
case GGML_OP_NORM:
+ case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ARGSORT:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_IM2COL:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_TANH:
break;
default:
return;
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
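+ // fail fast if any preallocated buffer would exceed the device's maximum
+ // allocation size, instead of letting the allocation fail later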
+ if (
+ x_sz > ctx->device->max_memory_allocation_size ||
+ y_sz > ctx->device->max_memory_allocation_size ||
+ d_sz > ctx->device->max_memory_allocation_size ||
+ split_k_size > ctx->device->max_memory_allocation_size) {
+ GGML_ABORT("Requested preallocation size is too large");
+ }
if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
}
}
}
-static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){
+static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node){
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
if (ggml_is_empty(node) || extra == nullptr) {
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_TANH:
break;
default:
return;
}
break;
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
+ case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_NORM:
+ case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ARGSORT:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_IM2COL:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
break;
default:
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
return;
}
- if (ctx->compute_ctx == nullptr) {
- ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
- ggml_vk_ctx_begin(ctx->device, ctx->compute_ctx);
+ vk_context compute_ctx;
+
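+ // ctx->compute_ctx is only a weak reference; the owning shared_ptrs stay
+ // with the backend context, so finished contexts can be freed on cleanup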
+ if (ctx->compute_ctx.expired()) {
+ compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ctx->compute_ctx = compute_ctx;
+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
+ } else {
+ compute_ctx = ctx->compute_ctx.lock();
}
switch (node->op) {
case GGML_OP_REPEAT:
- ggml_vk_repeat(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_repeat(ctx, compute_ctx, src0, node);
break;
case GGML_OP_GET_ROWS:
- ggml_vk_get_rows(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ADD:
- ggml_vk_add(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_add(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_MUL:
- ggml_vk_mul(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_DIV:
- ggml_vk_div(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_div(ctx, compute_ctx, src0, src1, node);
+ break;
+ case GGML_OP_CONCAT:
+ ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
+ break;
+ case GGML_OP_UPSCALE:
+ ggml_vk_upscale(ctx, compute_ctx, src0, node);
break;
case GGML_OP_SCALE:
- ggml_vk_scale(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_scale(ctx, compute_ctx, src0, node);
break;
case GGML_OP_SQR:
- ggml_vk_sqr(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_sqr(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CLAMP:
- ggml_vk_clamp(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_clamp(ctx, compute_ctx, src0, node);
+ break;
+ case GGML_OP_PAD:
+ ggml_vk_pad(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
- ggml_vk_cpy(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_cpy(ctx, compute_ctx, src0, node);
break;
case GGML_OP_NORM:
- ggml_vk_norm(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_norm(ctx, compute_ctx, src0, node);
+ break;
+ case GGML_OP_GROUP_NORM:
+ ggml_vk_group_norm(ctx, compute_ctx, src0, node);
break;
case GGML_OP_RMS_NORM:
- ggml_vk_rms_norm(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, node);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
- ggml_vk_unary(ctx, ctx->compute_ctx, src0, node);
+ case GGML_UNARY_OP_TANH:
+ ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
default:
return;
}
break;
case GGML_OP_DIAG_MASK_INF:
- ggml_vk_diag_mask_inf(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
break;
case GGML_OP_SOFT_MAX:
- ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ROPE:
- ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, src2, node);
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node);
break;
case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_argsort(ctx, compute_ctx, src0, node);
break;
case GGML_OP_SUM_ROWS:
- ggml_vk_sum_rows(ctx, ctx->compute_ctx, src0, node);
+ ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
+ break;
+ case GGML_OP_IM2COL:
+ ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
+ break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
+ break;
+ case GGML_OP_LEAKY_RELU:
+ ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
break;
case GGML_OP_MUL_MAT:
- ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
+ ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_MUL_MAT_ID:
- ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, src2, node);
+ ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node);
break;
default:
return;
}
- extra->ctx_idx = ctx->compute_ctx->idx;
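+ // remember (weakly) which context this node was recorded into, so that
+ // ggml_vk_compute_forward can look it up by node index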
+ ctx->tensor_ctxs[node_idx] = compute_ctx;
#ifdef GGML_VULKAN_CHECK_RESULTS
// Force context reset on each node so that each tensor ends up in its own context
last_node = true;
#endif
if (last_node) {
- ggml_vk_ctx_end(ctx->compute_ctx);
- ctx->compute_ctx->exit_tensor = node;
- ctx->compute_ctx = nullptr;
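+ // finish recording; submission and the fence wait happen in ggml_vk_compute_forward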
+ ggml_vk_ctx_end(compute_ctx);
+ compute_ctx->exit_tensor_idx = node_idx;
+ ctx->compute_ctx.reset();
}
}
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx){
ggml_tensor_extra_gpu * extra = nullptr;
switch (tensor->op) {
case GGML_OP_GET_ROWS:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
+ case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_NORM:
+ case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_NONE:
case GGML_OP_ARGSORT:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_IM2COL:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
+ case GGML_OP_REPEAT:
extra = (ggml_tensor_extra_gpu *) tensor->extra;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_TANH:
extra = (ggml_tensor_extra_gpu *) tensor->extra;
break;
default:
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
#ifdef GGML_VULKAN_CHECK_RESULTS
- ggml_vk_check_results_0(ctx, tensor);
+ ggml_vk_check_results_0(tensor);
#endif
- vk_context& subctx = ctx->gc.contexts[extra->ctx_idx];
+ vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
// Only run if ctx hasn't been submitted yet
- if (!subctx.seqs.empty()) {
+ if (!subctx->seqs.empty()) {
// Do staging buffer copies
- for (auto& cpy : subctx.in_memcpys) {
+ for (auto& cpy : subctx->in_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
- ggml_vk_submit(&subctx, ctx->fence);
+ ggml_vk_submit(subctx, ctx->fence);
}
- if (tensor == subctx.exit_tensor) {
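+ // the fence is waited on only once per context, at its final tensor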
+ if (tensor_idx == subctx->exit_tensor_idx) {
VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
ctx->device->device.resetFences({ ctx->fence });
// Do staging buffer copies
- for (auto& cpy : subctx.out_memcpys) {
+ for (auto& cpy : subctx->out_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
- subctx.in_memcpys.clear();
- subctx.out_memcpys.clear();
+ subctx->in_memcpys.clear();
+ subctx->out_memcpys.clear();
}
return true;
ctx->staging_offset = 0;
- ctx->compute_ctx = nullptr;
- ctx->transfer_ctx = nullptr;
+ ctx->tensor_ctxs.clear();
ctx->gc.contexts.clear();
}
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
- if (ctx->transfer_ctx == nullptr) {
+ vk_context transfer_ctx;
+
+ if (ctx->transfer_ctx.expired()) {
// Initialize new transfer context
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
- ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ ctx->transfer_ctx = transfer_ctx;
+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+ } else {
+ transfer_ctx = ctx->transfer_ctx.lock();
}
vk_buffer buf = extra->buffer_gpu.lock();
- ggml_vk_buffer_write_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+ ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
}
GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
- if (ctx->transfer_ctx == nullptr) {
+ vk_context transfer_ctx;
+
+ if (ctx->transfer_ctx.expired()) {
// Initialize new transfer context
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
- ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ ctx->transfer_ctx = transfer_ctx;
+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+ } else {
+ transfer_ctx = ctx->transfer_ctx.lock();
}
vk_buffer buf = extra->buffer_gpu.lock();
- ggml_vk_buffer_read_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+ ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
}
GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
- if (ctx->transfer_ctx == nullptr) {
+ vk_context transfer_ctx;
+
+ if (ctx->transfer_ctx.expired()) {
// Initialize new transfer context
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
- ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ ctx->transfer_ctx = transfer_ctx;
+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+ } else {
+ transfer_ctx = ctx->transfer_ctx.lock();
}
vk_buffer src_buf = src_extra->buffer_gpu.lock();
vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
- ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
+ ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
return true;
}
GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- if(ctx->transfer_ctx == nullptr) {
+ if(ctx->transfer_ctx.expired()) {
return;
}
- ggml_vk_ctx_end(ctx->transfer_ctx);
+ vk_context transfer_ctx = ctx->transfer_ctx.lock();
+
+ ggml_vk_ctx_end(transfer_ctx);
- for (auto& cpy : ctx->transfer_ctx->in_memcpys) {
+ for (auto& cpy : transfer_ctx->in_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
- ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
+ ggml_vk_submit(transfer_ctx, ctx->fence);
VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
ctx->device->device.resetFences({ ctx->fence });
- for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
+ for (auto& cpy : transfer_ctx->out_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
- ctx->transfer_ctx = nullptr;
+ ctx->transfer_ctx.reset();
}
static bool ggml_vk_is_empty(ggml_tensor * node) {
last_node -= 1;
}
+ // Reserve tensor context space for all nodes
+ ctx->tensor_ctxs.resize(cgraph->n_nodes);
+
for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node);
+ ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node);
}
for (int i = 0; i < cgraph->n_nodes; i++) {
continue;
}
- bool ok = ggml_vk_compute_forward(ctx, node);
+ bool ok = ggml_vk_compute_forward(ctx, node, i);
if (!ok) {
- fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ if (node->op == GGML_OP_UNARY) {
+ std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
+ } else {
+ std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
+ }
}
#ifdef GGML_VULKAN_CHECK_RESULTS
else {
- ggml_vk_check_results_1(ctx, node);
+ ggml_vk_check_results_1(node);
}
#endif
GGML_ASSERT(ok);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_TANH:
return ggml_is_contiguous(op->src[0]);
default:
return false;
}
return false;
} break;
- // case GGML_OP_REPEAT:
- // {
- // ggml_type src0_type = op->src[0]->type;
- // return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- // } break;
+ case GGML_OP_REPEAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
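+ // I32 and I16 are not supported by the repeat shader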
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ } break;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_NONE:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_RMS_NORM:
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
- case GGML_OP_RMS_NORM:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
+ case GGML_OP_PAD:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ARGSORT:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_IM2COL:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
return true;
default:
return false;
}
}
-static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) {
+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
void * tensor_data = tensor->data;
- if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+ const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
+
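+ // only read back through Vulkan when the tensor actually lives in a Vulkan buffer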
+ if (is_gpu) {
const size_t tensor_size = ggml_nbytes(tensor);
tensor_data = malloc(tensor_size);
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
std::cerr << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0);
- std::cerr << std::endl;
std::vector<const ggml_tensor *> done;
ggml_vk_print_graph_origin(tensor, done);
- if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+ if (is_gpu) {
free(tensor_data);
}
}
size_t comp_size;
size_t comp_nb[GGML_MAX_DIMS];
size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
- if (tensor->op == GGML_OP_TRANSPOSE) {
+static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+ if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
ggml_tensor * src2 = tensor->src[2];
struct ggml_init_params iparams = {
- /*.mem_size =*/ 1024*1024*1024,
+ /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(ctx, src0, "src0");
+ ggml_vk_print_tensor(src0, "src0");
}
}
if (src1 != nullptr) {
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(ctx, src1, "src1");
- std::cerr << "TENSOR CHECK: " << ggml_op_name(src1_clone->op) << " (check " << check_counter << ")" << std::endl;
- std::cerr << "src1_clone=" << tensor << " src1_clone->type: " << ggml_type_name(src1_clone->type) << " ne0=" << src1_clone->ne[0] << " nb0=" << src1_clone->nb[0] << " ne1=" << src1_clone->ne[1] << " nb1=" << src1_clone->nb[1] << " ne2=" << src1_clone->ne[2] << " nb2=" << src1_clone->nb[2] << " ne3=" << src1_clone->ne[3] << " nb3=" << src1_clone->nb[3] << std::endl;
- if (src1->src[0] != nullptr) {
- std::cerr << "src1->src[0]=" << src1->src[0] << " op=" << ggml_op_name(src1->src[0]->op) << " type=" << ggml_type_name(src1->src[0]->type) << " ne0=" << src1->src[0]->ne[0] << " nb0=" << src1->src[0]->nb[0] << " ne1=" << src1->src[0]->ne[1] << " nb1=" << src1->src[0]->nb[1] << " ne2=" << src1->src[0]->ne[2] << " nb2=" << src1->src[0]->nb[2] << " ne3=" << src1->src[0]->ne[3] << " nb3=" << src1->src[0]->nb[3] << std::endl;
- }
- if (src1->src[1] != nullptr) {
- std::cerr << "src1->src[1]=" << src1->src[1] << " op=" << ggml_op_name(src1->src[1]->op) << " type=" << ggml_type_name(src1->src[1]->type) << " ne0=" << src1->src[1]->ne[0] << " nb0=" << src1->src[1]->nb[0] << " ne1=" << src1->src[1]->ne[1] << " nb1=" << src1->src[1]->nb[1] << " ne2=" << src1->src[1]->ne[2] << " nb2=" << src1->src[1]->nb[2] << " ne3=" << src1->src[1]->ne[3] << " nb3=" << src1->src[1]->nb[3] << std::endl;
- }
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 0, 0);
- std::cerr << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 1, 0);
- std::cerr << std::endl;
- std::vector<const ggml_tensor *> done;
- ggml_vk_print_graph_origin(src1_clone, done);
+ ggml_vk_print_tensor(src1, "src1");
}
}
if (src2 != nullptr) {
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(ctx, src2, "src2");
- std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
- std::cerr << "src2_clone=" << tensor << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
- if (src2->src[0] != nullptr) {
- std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
- }
- if (src2->src[1] != nullptr) {
- std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
- }
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
- std::cerr << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
- std::cerr << std::endl;
- std::vector<const ggml_tensor *> done;
- ggml_vk_print_graph_origin(src2_clone, done);
+ ggml_vk_print_tensor(src2, "src2");
}
}
tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
} else if (tensor->op == GGML_OP_DIV) {
tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
+ } else if (tensor->op == GGML_OP_CONCAT) {
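+ // op_params[0] holds the concat dimension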
+ tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_UPSCALE) {
+ tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_SCALE) {
tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
} else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
} else if (tensor->op == GGML_OP_CLAMP) {
tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+ } else if (tensor->op == GGML_OP_PAD) {
+ tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
} else if (tensor->op == GGML_OP_ADD) {
tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_GROUP_NORM) {
+ tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_RMS_NORM) {
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
const int mode = ((int32_t *) tensor->op_params)[2];
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
- float freq_base = ((float *) tensor->op_params)[5];
- float freq_scale = ((float *) tensor->op_params)[6];
- float ext_factor = ((float *) tensor->op_params)[7];
- float attn_factor = ((float *) tensor->op_params)[8];
- float beta_fast = ((float *) tensor->op_params)[9];
- float beta_slow = ((float *) tensor->op_params)[10];
+ const float freq_base = ((float *) tensor->op_params)[5];
+ const float freq_scale = ((float *) tensor->op_params)[6];
+ const float ext_factor = ((float *) tensor->op_params)[7];
+ const float attn_factor = ((float *) tensor->op_params)[8];
+ const float beta_fast = ((float *) tensor->op_params)[9];
+ const float beta_slow = ((float *) tensor->op_params)[10];
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
} else if (tensor->op == GGML_OP_UNARY) {
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
+ break;
case GGML_UNARY_OP_RELU:
tensor_clone = ggml_relu(ggml_ctx, src0_clone);
break;
+ case GGML_UNARY_OP_TANH:
+ tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
+ break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
+ } else if (tensor->op == GGML_OP_IM2COL) {
+ const int32_t s0 = tensor->op_params[0];
+ const int32_t s1 = tensor->op_params[1];
+ const int32_t p0 = tensor->op_params[2];
+ const int32_t p1 = tensor->op_params[3];
+ const int32_t d0 = tensor->op_params[4];
+ const int32_t d1 = tensor->op_params[5];
+
+ const bool is_2D = tensor->op_params[6] == 1;
+ tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
+ } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
+ const int32_t dim = tensor->op_params[0];
+ const int32_t max_period = tensor->op_params[1];
+ tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+ } else if (tensor->op == GGML_OP_LEAKY_RELU) {
+ const float * op_params = (const float *)tensor->op_params;
+ tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
} else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(ctx, tensor_clone, "tensor_clone");
+ ggml_vk_print_tensor(tensor_clone, "tensor_clone");
}
comp_size = ggml_nbytes(tensor_clone);
}
ggml_free(ggml_ctx);
+
+ VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
}
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
std::cerr << std::endl << "Correct:" << std::endl;
ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
std::cerr << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0);
- std::cerr << std::endl << "Correct:" << std::endl;
- ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 1, 0);
- std::cerr << std::endl;
std::vector<const ggml_tensor *> done;
ggml_vk_print_graph_origin(tensor, done);
}
if (ggml_backend_buffer_is_vk(tensor->buffer)) {
free(tensor_data);
}
+
+ VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
}
#endif
+++ /dev/null
-#!/usr/bin/env python
-
-import logging
-import argparse
-import asyncio
-import os
-import sys
-from tempfile import gettempdir, NamedTemporaryFile
-
-logger = logging.getLogger("ggml-vk-generate-shaders")
-
-shader_f32 = """
-#define FLOAT_TYPE float
-"""
-shader_f16 = """
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#define FLOAT_TYPE float16_t
-"""
-shader_int8_ext = """
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-"""
-
-# Type-specific defines
-shader_f32_defines = """
-#define QUANT_K 1
-#define QUANT_R 1
-
-#define A_TYPE float
-"""
-shader_f16_defines = """
-#define QUANT_K 1
-#define QUANT_R 1
-
-#define A_TYPE float16_t
-"""
-shader_q4_0_defines = """
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_0
-{
- float16_t d;
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_0
-"""
-shader_q4_1_defines = """
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_1
-{
- float16_t d;
- float16_t m;
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_1
-"""
-shader_q5_0_defines = """
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_0
-{
- float16_t d;
- uint16_t qh[2];
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_0
-"""
-shader_q5_1_defines = """
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_1
-{
- float16_t d;
- float16_t m;
- uint qh;
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_1
-"""
-shader_q8_0_defines = """
-#define QUANT_K 32
-#define QUANT_R 1
-
-struct block_q8_0
-{
- float16_t d;
- int8_t qs[32];
-};
-
-#define A_TYPE block_q8_0
-"""
-
-# K-quants
-shader_q2_K_defines = """
-#define QUANT_K 256
-
-struct block_q2_K
-{
- uint8_t scales[QUANT_K/16];
- uint8_t qs[QUANT_K/4];
- f16vec2 d;
-};
-
-#define A_TYPE block_q2_K
-"""
-shader_q3_K_defines = """
-#define QUANT_K 256
-
-struct block_q3_K
-{
- uint8_t hmask[QUANT_K/8];
- uint8_t qs[QUANT_K/4];
- uint8_t scales[12];
- float16_t d;
-};
-
-#define A_TYPE block_q3_K
-"""
-shader_q4_K_defines = """
-#define QUANT_K 256
-
-struct block_q4_K
-{
- f16vec2 d;
- uint8_t scales[3*QUANT_K/64];
- uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q4_K
-"""
-shader_q5_K_defines = """
-#define QUANT_K 256
-
-struct block_q5_K
-{
- f16vec2 d;
- uint8_t scales[12];
- uint8_t qh[QUANT_K/8];
- uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q5_K
-"""
-shader_q6_K_defines = """
-#define QUANT_K 256
-
-struct block_q6_K
-{
- uint8_t ql[QUANT_K/2];
- uint8_t qh[QUANT_K/4];
- int8_t scales[QUANT_K/16];
- float16_t d;
-};
-
-#define A_TYPE block_q6_K
-"""
-
-# Dequant functions
-shader_float_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
-}
-"""
-
-shader_q4_0_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-}
-"""
-
-shader_q4_1_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const float m = float(data_a[a_offset + ib].m);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(vui & 0xF, vui >> 4) * d + m;
-}
-"""
-
-shader_q5_0_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-}
-"""
-
-shader_q5_1_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const float m = float(data_a[a_offset + ib].m);
- const uint uint_qh = data_a[a_offset + ib].qh;
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-}
-"""
-
-shader_q8_0_dequant_func = """
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
-}
-"""
-
-# MULMAT
-
-mulmat_head = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#ifdef MUL_MAT_ID
-#extension GL_EXT_buffer_reference2 : require
-#extension GL_EXT_nonuniform_qualifier : require
-#extension GL_EXT_scalar_block_layout : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-
-#define EXPERT_COUNT 8
-#endif
-
-#ifndef LOAD_VEC_A
-#define LOAD_VEC_A 1
-#endif
-#ifndef LOAD_VEC_B
-#define LOAD_VEC_B 1
-#endif
-"""
-
-mulmat_body1 = """
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-layout (push_constant) uniform parameter
-{
- uint M;
- uint N;
- uint K;
- uint stride_a;
- uint stride_b;
- uint stride_d;
- uint k_split;
-
- uint ne02;
- uint ne12;
- uint broadcast2;
- uint broadcast3;
-
- uint batch_stride_a;
- uint batch_stride_b;
- uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
- uint expert_stride_a;
- uint expert_stride_b0;
- uint expert_stride_b1;
- uint expert_stride_d;
-
- uint ids_stride;
-
- uint n_as;
- uint nei0;
- uint nei1;
- uint nbi1;
- uint ne11;
-#endif
-} p;
-
-layout (constant_id = 1) const uint BM = 64;
-layout (constant_id = 2) const uint BN = 64;
-layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
-layout (constant_id = 4) const uint WM = 32;
-layout (constant_id = 5) const uint WN = 32;
-layout (constant_id = 6) const uint WMITER = 2;
-layout (constant_id = 7) const uint TM = 4;
-layout (constant_id = 8) const uint TN = 2;
-layout (constant_id = 9) const uint WARP = 32;
-
-shared FLOAT_TYPE buf_a[BM * (BK+1)];
-shared FLOAT_TYPE buf_b[BN * (BK+1)];
-
-#ifdef MUL_MAT_ID
-shared u8vec2 rowids[2048];
-#endif
-
-void main() {
-#ifdef MUL_MAT_ID
- const uint batch_idx = gl_GlobalInvocationID.z / p.n_as;
- const uint expert_idx = gl_GlobalInvocationID.z % p.n_as;
-#else
- const uint batch_idx = gl_GlobalInvocationID.z;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
- const uint blocks_m = (p.M + BM - 1) / BM;
- const uint ir = gl_WorkGroupID.x % blocks_m;
- const uint ik = gl_WorkGroupID.x / blocks_m;
- const uint ic = gl_WorkGroupID.y;
-
- const uint warp_i = gl_LocalInvocationID.x / WARP;
- const uint warp_r = warp_i % (BM / WM);
- const uint warp_c = warp_i / (BM / WM);
-
- const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
- const uint WSUBM = WM / WMITER;
- const uint WSUBN = WN / WNITER;
-
- const uint tiw = gl_LocalInvocationID.x % WARP;
- const uint tiwr = tiw % (WSUBM / TM);
- const uint tiwc = tiw / (WSUBM / TM);
-
- const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
- const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
- const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
- const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
-
- const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
- const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
-
-#ifdef MUL_MAT_ID
- uint _ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
- if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- rowids[_ne1] = u8vec2(ii0, ii1);
- _ne1++;
- }
- }
- }
-
- const u8vec2 id = rowids[ir * BN + ic];
-#endif
-
- const uint start_k = ik * p.k_split;
- const uint end_k = min(p.K, (ik + 1) * p.k_split);
-
- uint pos_a = (
-#ifdef MUL_MAT_ID
- expert_idx * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
- uint pos_b = (
-#ifdef MUL_MAT_ID
- id.y * p.expert_stride_b1 +
- (id.x % p.ne11) * p.expert_stride_b0 +
-#endif
- batch_idx * p.batch_stride_b +
- ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
-
- float sums[WMITER * TM * WNITER * TN];
- FLOAT_TYPE cache_a[WMITER * TM];
- FLOAT_TYPE cache_b[WNITER * TN];
-
- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
- sums[i] = 0.0f;
- }
-
- [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
- [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {"""
-
-mulmat_load_scalar = """
-#if LOAD_VEC_A == 8
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
- buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
- buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
- buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
- buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
- buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
- buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
- buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
-#elif LOAD_VEC_A == 4
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
- buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
- buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
- buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
-#else
- if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
- } else {
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
- }
-#endif
-"""
-
-mulmat_load_q4_0 = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q4_1 = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q5_0 = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q5_1 = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint uint_qh = data_a[ib].qh;
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q8_0 = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 16;
- const uint iqs = (idx & 0xF) * 2;
-
- const float d = float(data_a[ib].d);
- const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
-
-
-mulmat_load_q2_K = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
- const uint scalesi = iqs / 8; // 0..15
- const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-
- const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
- const uint scales = data_a[ib].scales[scalesi];
- const vec2 d = vec2(data_a[ib].d);
-
- const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
-
-mulmat_load_q3_K = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 64; // 0,1
- const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
- const uint hmi = (iqs % 16) * 2; // 0,2,4..30
- const uint j = (iqs % 64) / 4; // 0..3
- const uint is = iqs / 8; // 0..15
- const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
- const uint qsshift = halfsplit * 2; // 0,2,4,6
- const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
-
- const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
- is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
- is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
- (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
- const float dl = float(data_a[ib].d) * float(us - 32);
-
- buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
- buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
-
-mulmat_load_q4_K = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 32; // 0,1,2,3
- const uint b = (iqs % 32) / 16; // 0,1
- const uint is = 2 * n + b; // 0..7
- const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
-
- const vec2 loadd = vec2(data_a[ib].d);
-
- uint8_t sc;
- uint8_t mbyte;
- if (is < 4) {
- sc = uint8_t(data_a[ib].scales[is ] & 63);
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
- }
- const float d = loadd.x * sc;
- const float m = loadd.y * mbyte;
-
- buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
- buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);"""
-
-mulmat_load_q5_K = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 32; // 0,1,2,3
- const uint b = (iqs % 32) / 16; // 0,1
- const uint is = 2 * n + b; // 0..7
- const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const uint qhi = (iqs % 16) * 2; // 0,2,4..30
-
- const uint8_t hm = uint8_t(1 << (iqs / 16));
-
- const vec2 loadd = vec2(data_a[ib].d);
-
- uint8_t sc;
- uint8_t mbyte;
- if (is < 4) {
- sc = uint8_t(data_a[ib].scales[is ] & 63);
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
- }
- const float d = loadd.x * sc;
- const float m = loadd.y * mbyte;
-
- buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
- buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);"""
-
-mulmat_load_q6_K = """
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 64; // 0,1
- const uint b = (iqs % 64) / 32; // 0,1
- const uint is_b = (iqs % 16) / 8; // 0,1
- const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
- const uint is = 8 * n + qhshift + is_b; // 0..15
- const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
- const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
-
- const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
-
- buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
- buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));"""
-
-mulmat_body2 = """
- }
- [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
-#if LOAD_VEC_B == 8
- const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
- buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
- buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
- buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
- buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
-#elif LOAD_VEC_B == 4
- const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
-#else
- if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
- } else {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
- }
-#endif
- }
-
- barrier();
-
- pos_a += BK / LOAD_VEC_A;
- pos_b += BK / LOAD_VEC_B;
-
- for (uint i = 0; i < BK; i++) {
- // Load from shared into cache
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint j = 0; j < TM; j++) {
- cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
- }
- }
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint j = 0; j < TN; j++) {
- cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
- }
- }
-
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
- }
- }
- }
- }
- }
-
- barrier();
- }
-
- const uint dr = ir * BM + warp_r * WM;
- const uint dc = ic * BN + warp_c * WN;
-
- const uint offsets =
-#ifdef MUL_MAT_ID
- expert_idx * p.expert_stride_d +
-#endif
- batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-
- const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
- const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
- }
- }
- }
- }
- }
-}
-"""
-
-mulmat_split_k_reduce_src = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ne;
- uint k_num;
-} p;
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
-
- if (idx >= p.ne) {
- return;
- }
-
- float result = 0.0f;
-
- [[unroll]] for (uint i = 0; i < p.k_num; i++) {
- result += data_a[i * p.ne + idx];
- }
-
- data_d[idx] = result;
-}
-"""
-
-# DEQUANT SHADER
-dequant_head = """#version 450
-
-#extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint M;
- uint K;
- uint stride_a;
- uint stride_b;
- uint nel;
-} p;
-"""
-
-dequant_f32_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.x * 16;
-
- if (i >= p.nel) {
- return;
- }
-
- [[unroll]] for (uint l = 0; l < 16; l++) {
- data_b[i + l] = D_TYPE(data_a[i + l]);
- }
-}
-"""
-
-dequant_q4_0_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const float dm = -8.0f * d;
-
- const uint q_idx = 8*il;
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm);
- data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + dm);
- }
-}
-"""
-
-dequant_q4_1_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
-
- const uint q_idx = 8*il;
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
- data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
- }
-}
-"""
-
-dequant_q5_0_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-
- const uint q_idx = 8*il;
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- const uint iqs = q_idx + l;
- const uint vui = uint(data_a[ib].qs[iqs]);
- data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
- data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
- }
-}
-"""
-
-dequant_q5_1_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint qh = data_a[ib].qh;
-
- const uint q_idx = 8*il;
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- const uint iqs = q_idx + l;
- const uint vui = uint(data_a[ib].qs[iqs]);
- data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
- data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
- }
-}
-"""
-
-dequant_q8_0_body = """
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 16*il;
-
- const float d = float(data_a[ib].d);
-
- const uint q_idx = 16*il;
-
- [[unroll]] for (uint l = 0; l < 16; l += 2) {
- data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
- data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
- }
-}
-"""
-
-# K-quants
-dequant_q2_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint tid = gl_LocalInvocationID.x;
- const uint ip = tid / 32;
- const uint il = tid - 32 * ip;
- const uint is = 8 * ip + il / 16;
-
- const uint y_idx = i * QUANT_K + 128 * ip + il;
-
- const uint ql_idx = 32 * ip + il;
- const uint8_t qs = data_a[i].qs[32 * ip + il];
-
- FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
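-        // each scales byte packs a 4-bit scale (low nibble) and a 4-bit min (high nibble)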
- data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
- data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
- data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
- data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
- }
-}
-"""
-dequant_q3_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint r = gl_LocalInvocationID.x / 4;
- const uint tid = r / 2;
- const uint is0 = r % 2;
- const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
- const uint n = tid / 4;
- const uint j = tid - 4*n;
-
- const uint8_t m = uint8_t(1 << (4*n + j));
- const uint is = 8*n + 2*j + is0;
- const uint shift = 2*j;
-
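-        // reassemble the 6-bit scale from its 4 low bits and 2 high bits, which are stored separately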
- const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
- is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
- is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
- (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
- const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
- const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
-
- const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
- const uint qs_idx = 32*n;
-
- for (uint l = l0; l < l0 + 4; ++l) {
- data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
- }
- }
-}
-"""
-dequant_q4_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint tid = gl_LocalInvocationID.x;
- const uint il = tid / 8;
- const uint ir = tid % 8;
- const uint is = 2 * il;
- const uint n = 4;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
- const uint y_idx = i * QUANT_K + 64 * il + n * ir;
- const uint qs_idx = 32*il + n * ir;
-
- uint8_t sc;
- uint8_t m;
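-        // 6-bit scales and mins: the first four of each are stored whole, the rest are split across bytes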
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is] & 63);
- m = uint8_t(data_a[i].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
- }
- const FLOAT_TYPE d1 = dall * sc;
- const FLOAT_TYPE m1 = dmin * m;
-
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is + 1] & 63);
- m = uint8_t(data_a[i].scales[is + 5] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
- }
- const FLOAT_TYPE d2 = dall * sc;
- const FLOAT_TYPE m2 = dmin * m;
-
- [[unroll]] for (uint l = 0; l < n; ++l) {
- data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
- data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
- }
- }
-}
-"""
-dequant_q5_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint tid = gl_LocalInvocationID.x;
- const uint il = tid / 16;
- const uint ir = tid % 16;
- const uint is = 2 * il;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
- const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
- const uint qs_idx = 32*il + 2 * ir;
- const uint qh_idx = 2 * ir;
-
- uint8_t sc;
- uint8_t m;
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is] & 63);
- m = uint8_t(data_a[i].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
- }
- const FLOAT_TYPE d1 = dall * sc;
- const FLOAT_TYPE m1 = dmin * m;
-
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is + 1] & 63);
- m = uint8_t(data_a[i].scales[is + 5] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
- }
- const FLOAT_TYPE d2 = dall * sc;
- const FLOAT_TYPE m2 = dmin * m;
-
- const uint8_t hm1 = uint8_t(1 << (2 * il ));
- const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
- data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
- data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
- data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
- data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
- }
-}
-"""
-dequant_q6_K_body = """
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
- const uint tid = gl_LocalInvocationID.x;
- const uint ip = tid / 32;
- const uint il = tid - 32 * ip;
- const uint is = 8 * ip + il / 16;
-
- const uint y_idx = i * QUANT_K + 128 * ip + il;
-
- const uint ql_idx = 64 * ip + il;
- const uint8_t qh = data_a[i].qh[32 * ip + il];
-
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
-
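-        // each 6-bit quant is 4 low bits from ql plus 2 high bits from qh, recentered by -32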
- data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
- data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
- data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
- data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
- }
-}
-"""
-
-# Mul Mat Vec
-mul_mat_vec_head = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_8bit_storage : require
-
-#ifdef MUL_MAT_ID
-#define EXPERT_COUNT 8
-#endif
-"""
-
-
-mul_mat_vec_layout = """
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-layout (push_constant) uniform parameter
-{
- uint ncols;
- uint stride_a;
- uint stride_b;
- uint stride_d;
-
- uint ne02;
- uint ne12;
- uint broadcast2;
- uint broadcast3;
-
- uint batch_stride_a;
- uint batch_stride_b;
- uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
- uint expert_stride_a;
- uint expert_stride_b0;
- uint expert_stride_b1;
- uint expert_stride_d0;
- uint expert_stride_d1;
-
- uint ne11;
- uint nei0;
- uint nbi1;
- uint n_as;
-#endif
-} p;
-"""
-
-mul_mat_vec_body = """
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
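-// the workgroup size is supplied at pipeline creation through specialization constant 0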
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint tid = gl_LocalInvocationID.x;
- const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
- const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
- const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
- const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
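-    // offsets into A, B and D for the current batch (and expert, when MUL_MAT_ID is defined)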
- const uint a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a;
- const uint b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx0 % p.ne11) * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_b;
- const uint d_offset =
-#ifdef MUL_MAT_ID
- expert_idx0 * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_d;
-
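-    // distance between the two values dequantized per quant position: adjacent for QUANT_R == 1, QUANT_K/2 apart otherwise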
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
- tmp[tid] = FLOAT_TYPE(0.0f);
-
- [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
- const uint col = i*BLOCK_SIZE + 2*tid;
- const uint ib = (row*p.ncols + col)/QUANT_K; // block index
- const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
- const uint iybs = col - col%QUANT_K; // y block start index
-
- vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
-
- // matrix multiplication
- tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
- FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
-"""
-
-# K-quants
-mul_mat_vec_q2_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
- const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
- const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
- const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
- const uint a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a;
- const uint b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx0 % p.ne11) * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_b;
- const uint d_offset =
-#ifdef MUL_MAT_ID
- expert_idx0 * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_d;
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // always 0, or alternating 0 and 1
-
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
-
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
- const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint s_offset = 8*v_im;
- const uint y_offset = 128*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
- sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
- sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
- }
- tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
-"""
-mul_mat_vec_q3_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
- const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
- const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
- const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
- const uint a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a;
- const uint b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx0 % p.ne11) * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_b;
- const uint d_offset =
-#ifdef MUL_MAT_ID
- expert_idx0 * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_d;
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // always 0, or alternating 0 and 1
-
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
-
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
- const uint8_t m = uint8_t(1 << (4 * v_im));
-
- const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint y_offset = 128*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- const uint s_shift = 4 * v_im;
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
- sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
- + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
- }
- tmp[16 * ix + tid] += d * sum;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
-"""
-mul_mat_vec_q4_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
- const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
- const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
- const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
- const uint a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a;
- const uint b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx0 % p.ne11) * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_b;
- const uint d_offset =
-#ifdef MUL_MAT_ID
- expert_idx0 * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_d;
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // always 0, or alternating 0 and 1
-
- const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
-
- const uint il = tid/step; // 0...3
- const uint ir = tid - step*il; // 0...7 or 0...3
- const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
-
- const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const uint v_in = il % 2;
-
- const uint l0 = n * (2 * ir + v_in); // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint y_offset = 64*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y1_idx = i * QUANT_K + y_offset;
- const uint y2_idx = y1_idx + 128;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-#if K_QUANTS_PER_ITERATION == 2
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
-
- const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
- const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
- const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
- const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
- const FLOAT_TYPE smin = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
- );
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
-#else
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
-
- const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
- const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
- const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
- const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
- const FLOAT_TYPE smin = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
- );
-
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
-#endif
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
-"""
-mul_mat_vec_q5_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
- const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
- const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
- const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
- const uint a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a;
- const uint b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx0 % p.ne11) * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_b;
- const uint d_offset =
-#ifdef MUL_MAT_ID
- expert_idx0 * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_d;
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/2;  // 0...15
-    const uint ix  = gl_LocalInvocationID.x%2;  // 0 or 1
-
-    const uint il = tid/4; // 0...3
-    const uint ir = tid - 4*il; // 0...3
-
- const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const uint v_in = il % 2;
-
- const uint l0 = 4*ir + 2*v_in; // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint y_offset = 64*v_im + l0;
-
- const uint8_t hm1 = uint8_t(1 << (2*v_im));
- const uint8_t hm2 = uint8_t(hm1 << 4);
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
- const uint y1_idx = i * QUANT_K + y_offset;
- const uint y2_idx = y1_idx + 128;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
-
- const FLOAT_TYPE sx = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx ]) * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE sy = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE sz = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y2_idx ]) * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE sw = FLOAT_TYPE(
- FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
- + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
- );
- const FLOAT_TYPE smin = FLOAT_TYPE(
- (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
- + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
- );
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
-"""
-mul_mat_vec_q6_K_body = """
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint batch_idx = gl_GlobalInvocationID.y;
-#ifdef MUL_MAT_ID
- const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
- const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
-#endif
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-
-#ifdef MUL_MAT_ID
- const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
-#endif
-
- const uint a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.expert_stride_a +
-#endif
- batch_idx_a * p.batch_stride_a;
- const uint b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx0 % p.ne11) * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_b;
- const uint d_offset =
-#ifdef MUL_MAT_ID
- expert_idx0 * p.expert_stride_b0 +
- expert_idx1 * p.expert_stride_b1 +
-#endif
- batch_idx * p.batch_stride_d;
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // always 0, or alternating 0 and 1
-
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
-
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
-#if K_QUANTS_PER_ITERATION == 1
- const uint l0 = v_in; // 0...15
- const uint is = 0;
-#else
- const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
- const uint is = v_in / 4;
-#endif
-
- const uint ql_offset = 64*v_im + l0;
- const uint qh_offset = 32*v_im + l0;
- const uint s_offset = 8*v_im + is;
- const uint y_offset = 128*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-#if K_QUANTS_PER_ITERATION == 1
- FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
- tmp[16 * ix + tid] += sum;
-#else
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- [[unroll]] for (int l = 0; l < 4; ++l) {
- sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
- + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
- }
- tmp[16 * ix + tid] += sum;
-#endif
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
-"""
-
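-# mul_mat_vec for f16 src0 that is transposed and permuted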
-mul_mat_p021_src = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
- uint ncols_x;
- uint nrows_x;
- uint nchannels_x;
- uint nchannels_y;
- uint b_offset;
- uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint row_x = gl_GlobalInvocationID.y;
- const uint channel = gl_GlobalInvocationID.z;
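-    // map the output channel to its source channel in x (y may have a multiple of x's channels)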
- const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
-
- const uint nrows_y = p.ncols_x;
- const uint nrows_dst = p.nrows_x;
- const uint row_dst = row_x;
-
- tmp[tid] = FLOAT_TYPE(0.0f);
-
- for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
-
- if (col_x >= p.ncols_x) {
- break;
- }
-
- // x is transposed and permuted
- const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
- const uint row_y = col_x;
-
- // y is not transposed but permuted
- const uint iy = channel*nrows_y + row_y;
-
- tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
- }
-
- // dst is not transposed and not permuted
- const uint idst = channel*nrows_dst + row_dst;
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
-
- if (tid == 0) {
- dst[idst] = tmp[0];
- }
-}
-"""
-
-
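-# mul_mat_vec for non-contiguous f16 src0, with explicit row and channel strides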
-mul_mat_nc_src = """#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
- uint ncols_x;
- uint nrows_x;
- uint row_stride_x;
- uint channel_stride_x;
- uint channel_x_divisor;
- uint b_offset;
- uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint row_x = gl_GlobalInvocationID.y;
- const uint channel = gl_GlobalInvocationID.z;
- const uint channel_x = channel / p.channel_x_divisor;
-
- const uint nrows_y = p.ncols_x;
- const uint nrows_dst = p.nrows_x;
- const uint row_dst = row_x;
-
- const uint idst = channel*nrows_dst + row_dst;
-
- tmp[tid] = 0.0f;
-
- for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
-
- if (col_x >= p.ncols_x) {
- break;
- }
-
- const uint row_y = col_x;
-
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
- const uint iy = channel*nrows_y + row_y;
-
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
- tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
-
- if (tid == 0) {
- dst[idst] = tmp[0];
- }
-}
-"""
-
-generic_head = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint KX;
- uint KY;
- float param1;
- float param2;
-} p;
-"""
-
-generic_unary_op_head = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint ne;
- uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
- uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
- uint d_offset;
- float param1; float param2;
-} p;"""
-
-generic_unary_op_layout = """
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};"""
-
-generic_unary_op_funcs = """
-uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
- return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint dst_idx(uint idx) {
- const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
- const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
- const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
- const uint i12_offset = i12*p.ne11*p.ne10;
- const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
- const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
- return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
-}"""
-
-generic_unary_op_main = """
-void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
- return;
- }
-"""
-
-generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_layout}\n{generic_unary_op_funcs}\n{generic_unary_op_main}"
-
-generic_binary_op_head = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint ne;
- uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
- uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
- uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
- uint d_offset;
- float param1; float param2;
-} p;"""
-
-generic_binary_op_layout = """
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};"""
-
-generic_binary_op_funcs = """
-uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
- return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint src1_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-
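-    // the modulo operations broadcast src1 along any dimension where it is smaller than src0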
- return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
-}
-
-uint dst_idx(uint idx) {
- const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
- const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
- const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
- const uint i22_offset = i22*p.ne21*p.ne20;
- const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
- const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
- return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
-}"""
-
-generic_binary_op_main = """
-void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
- return;
- }
-"""
-
-generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
-
-# MUL F32
-mul_body = """
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
-}
-"""
-
-# ADD
-add_body = """
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
-}
-"""
-
-# SCALE
-scale_body = """
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
-}
-"""
-
-# SQR
-sqr_body = """
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
-}
-"""
-
-# CLAMP
-clamp_body = """
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
-}
-"""
-
-# CPY
-cpy_end = """
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
-}
-"""
-# f16 -> f16 copy needs its own body without the D_TYPE() cast; the cast causes an optimization error otherwise
-cpy_f16_f16_end = """
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
-}
-"""
-
-# GET_ROWS
-get_rows_float_body = """
-void main() {
- const uint i00 = gl_GlobalInvocationID.x;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
- if (i00 >= p.ne00) {
- return;
- }
-
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
-#else
- data_d[d_offset + i00] = data_a[a_offset + i00];
-#endif
-}
-"""
-
-get_rows_body = """
-void main() {
- const uint i00 = (gl_GlobalInvocationID.x)*2;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
- if (i00 >= p.ne00) {
- return;
- }
-
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
- const uint ib = a_offset + i00/QUANT_K; // block index
- const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
- const uint iybs = i00 - i00%QUANT_K; // dst block start index
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
- vec2 v = dequantize(ib, iqs, 0);
-
- data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
- data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
-}
-"""
-
-# UNARY
-gelu_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const float GELU_COEF_A = 0.044715f;
- const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
- const uint i = gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- const float xi = float(data_a[i]);
- const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
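-    // tanh approximation of GELU: equivalent to 0.5*x*(1 + tanh(val)), written with exp() only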
- data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
-}
-"""
-
-silu_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- const float xi = float(data_a[i]);
- data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));
-}
-"""
-
-relu_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- data_d[i] = max(float(data_a[i]), 0);
-}
-"""
-
-# DIAG_MASK_INF
-diag_mask_inf_head = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint ncols;
- uint rows_per_channel;
- uint n_past;
-} p;
-"""
-diag_mask_inf_body = """
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint col = gl_GlobalInvocationID.y;
- const uint row = gl_GlobalInvocationID.x;
-
- if (col >= p.ncols) {
- return;
- }
-
- const uint i = row*p.ncols + col;
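-    // 0xFF800000 is the IEEE 754 bit pattern of -infinity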
- if (col > p.n_past + row % p.rows_per_channel) {
- data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
- } else {
- data_d[i] = D_TYPE(data_a[i]);
- }
-}
-"""
-
-# NORMS
-norm_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared vec2 sum[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint tid = gl_LocalInvocationID.x;
-
- sum[tid] = vec2(0.0f, 0.0f);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- const float xi = float(data_a[row*p.KX + col]);
- sum[tid].x += xi;
- sum[tid].y += xi * xi;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- sum[tid] += sum[tid + s];
- }
- barrier();
- }
-
- const float mean = sum[0].x / p.KX;
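-    // variance via E[x^2] - E[x]^2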
- const float var = sum[0].y / p.KX - mean * mean;
- const float inv_std = inversesqrt(var + p.param1);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
- }
-}
-"""
-
-rms_norm_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE sum[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.x;
- const uint tid = gl_LocalInvocationID.x;
-
- sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
- sum[tid] += xi * xi;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- sum[tid] += sum[tid + s];
- }
- barrier();
- }
-
- const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
- const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
- }
-}
-"""
-
-# SOFT_MAX
-soft_max_head = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint KX;
- uint KY;
- uint KZ;
- float scale;
- float max_bias;
- float m0;
- float m1;
- uint n_head_log2;
-} p;
-"""
-
-soft_max_body = """
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) readonly buffer Z {C_TYPE data_c[];};
-layout (binding = 3) buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE vals[BLOCK_SIZE];
-
-void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint rowx = gl_WorkGroupID.x;
- const uint rowy = rowx % p.KY;
-
- float slope = 0.0f;
-
- // ALiBi
- if (p.max_bias > 0.0f) {
- const uint h = rowx/p.KY; // head index
-
- const float base = h < p.n_head_log2 ? p.m0 : p.m1;
- const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
-
- slope = pow(base, exp);
- }
-
- // Find max
- vals[tid] = uintBitsToFloat(0xFF800000);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f));
- }
-
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- vals[tid] = max(vals[tid], vals[tid + s]);
- }
- barrier();
- }
-
- const FLOAT_TYPE max_val = vals[0];
- barrier();
-
- // Sum up values
- vals[tid] = FLOAT_TYPE(0.0f);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- const uint i = rowx * p.KX + col;
- const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
- vals[tid] += val;
- data_d[i] = D_TYPE(val);
- }
-
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- vals[tid] += vals[tid + s];
- }
- barrier();
- }
-
- const D_TYPE divisor = D_TYPE(vals[0]);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- data_d[rowx*p.KX + col] /= divisor;
- }
-}
-"""
-
-# ROPE
-rope_src = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ncols;
- float freq_scale;
- uint p_delta_rows;
- float freq_base;
- float ext_factor;
- float attn_factor;
- float corr_dims[4];
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
- return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
- float mscale = p.attn_factor;
- // Get n-d rotational scaling corrected for extrapolation
- float theta_interp = p.freq_scale * theta_extrap;
- float theta = theta_interp;
- if (p.ext_factor != 0.0f) {
- float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
- }
- cos_theta = cos(theta) * mscale;
- sin_theta = sin(theta) * mscale;
-}
-
-void main() {
- const uint col = gl_GlobalInvocationID.y * 2;
- const uint row = gl_GlobalInvocationID.x;
-
- if (col >= p.ncols) {
- return;
- }
-
- const uint i = row*p.ncols + col;
- const uint i2 = row/p.p_delta_rows;
-
- const int pos = data_b[i2];
- const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols);
-
- float cos_theta, sin_theta;
- rope_yarn(theta_base, col, cos_theta, sin_theta);
-
- const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + 1]);
-
- data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
-"""
-
-rope_neox_src = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ncols;
- uint ndims;
- float freq_scale;
- uint p_delta_rows;
- float freq_base;
- float ext_factor;
- float attn_factor;
- float corr_dims[4];
- float theta_scale;
- float inv_ndims;
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
- return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
- float mscale = p.attn_factor;
- // Get n-d rotational scaling corrected for extrapolation
- float theta_interp = p.freq_scale * theta_extrap;
- float theta = theta_interp;
- if (p.ext_factor != 0.0f) {
- float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
- }
- cos_theta = cos(theta) * mscale;
- sin_theta = sin(theta) * mscale;
-}
-
-void main() {
- const uint col = gl_GlobalInvocationID.y * 2;
- const uint row = gl_GlobalInvocationID.x;
-
- if (col >= p.ncols) {
- return;
- }
-
- const uint ib = col / p.ndims;
- const uint ic = col % p.ndims;
-
- if (ib > 0) {
- const uint i = row*p.ncols + ib*p.ndims + ic;
-
- data_d[i + 0] = data_a[i + 0];
- data_d[i + 1] = data_a[i + 1];
-
- return;
- }
-
- const uint i = row*p.ncols + ib*p.ndims + ic/2;
- const uint i2 = row/p.p_delta_rows;
-
- const float cur_rot = p.inv_ndims * ic - ib;
-
- const int pos = data_b[i2];
- const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f);
-
- float cos_theta, sin_theta;
- rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
-
- const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + p.ndims/2]);
-
- data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
-"""
-
-argsort_src = """
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) buffer D {int data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ncols;
- bool ascending;
-} p;
-
-void swap(uint idx0, uint idx1) {
- int tmp = data_d[idx0];
- data_d[idx0] = data_d[idx1];
- data_d[idx1] = tmp;
-}
-
-void main() {
- // bitonic sort
- const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
-
- if (col >= p.ncols) {
- return;
- }
-
- const uint a_idx = row * p.ncols;
- const uint d_idx = row * p.ncols;
-
- // initialize indices
- if (col < p.ncols) {
- data_d[col] = col;
- }
- barrier();
-
- for (uint k = 2; k <= p.ncols; k *= 2) {
- for (uint j = k / 2; j > 0; j /= 2) {
- const uint ixj = col ^ j;
- if (ixj > col) {
- if ((col & k) == 0) {
- if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
- swap(d_idx + col, d_idx + ixj);
- }
- } else {
- if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
- swap(d_idx + col, d_idx + ixj);
- }
- }
- }
- barrier();
- }
- }
-}
-"""
-
-GLSLC = "glslc"
-
-VK_NUM_TYPES = 16
-
-GGML_TYPE_F32 = 0
-GGML_TYPE_F16 = 1
-GGML_TYPE_Q4_0 = 2
-GGML_TYPE_Q4_1 = 3
-GGML_TYPE_Q5_0 = 6
-GGML_TYPE_Q5_1 = 7
-GGML_TYPE_Q8_0 = 8
-GGML_TYPE_Q8_1 = 9
-GGML_TYPE_Q2_K = 10
-GGML_TYPE_Q3_K = 11
-GGML_TYPE_Q4_K = 12
-GGML_TYPE_Q5_K = 13
-GGML_TYPE_Q6_K = 14
-GGML_TYPE_Q8_K = 15
-
-
-type_names = {
- GGML_TYPE_F32: "f32",
- GGML_TYPE_F16: "f16",
- GGML_TYPE_Q4_0: "q4_0",
- GGML_TYPE_Q4_1: "q4_1",
- GGML_TYPE_Q5_0: "q5_0",
- GGML_TYPE_Q5_1: "q5_1",
- GGML_TYPE_Q8_0: "q8_0",
- GGML_TYPE_Q8_1: "q8_1",
- GGML_TYPE_Q2_K: "q2_K",
- GGML_TYPE_Q3_K: "q3_K",
- GGML_TYPE_Q4_K: "q4_K",
- GGML_TYPE_Q5_K: "q5_K",
- GGML_TYPE_Q6_K: "q6_K",
- GGML_TYPE_Q8_K: "q8_K",
-}
-
-K_QUANTS_PER_ITERATION = 2
-
-ASYNCIO_CONCURRENCY = 64
-
-output_dir = gettempdir()
-
-lock = asyncio.Lock()
-shader_fnames = []
-
-
-async def string_to_spv(name, code, defines, fp16=True):
- f = NamedTemporaryFile(mode="w", delete=False)
- f.write(code)
- f.flush()
-
- name = f"{name}{'_fp32' if not fp16 else ''}"
- fname = os.path.join(output_dir, f"{name}.comp")
-
- cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", f.name, "-o", fname]
-
- cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
-
- proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
-
- stdout, stderr = await proc.communicate()
-
- stdout = stdout.decode()
- error = stderr.decode()
-
- if proc.returncode:
- # Generate preprocessed code
- cmd = [GLSLC, "-E", f.name]
- cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
-
- proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
-
- stdout, stderr = await proc.communicate()
-
- logger.info(" ".join(cmd))
-
- if proc.returncode:
- raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}")
-
- preprocessed_code = stdout.decode()
-
- cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
- code_with_lines = "\n".join([f"{i + 1}: {line}" for i, line in enumerate(preprocessed_code.splitlines())])
- logger.error(f"cannot compile {name}\n\n{code_with_lines}\n\n{error}")
- f.close()
- os.remove(f.name)
- sys.exit(proc.returncode)
-
- f.close()
- os.remove(f.name)
-
- async with lock:
- shader_fnames.append((name, fname))
-
-
-async def main():
- logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")
-
- tasks = []
-
- stream = []
-
- for fp16 in (False, True):
- # mulmat
- if fp16:
- shader_float_type = shader_f16
- load_vec = "8"
- vec_type_f16 = "f16mat2x4"
- vec_type = "mat2x4"
- else:
- shader_float_type = shader_f32
- load_vec = "4"
- vec_type_f16 = "f16vec4"
- vec_type = "vec4"
-
- stream.clear()
- stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
- tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
-
- tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
- tasks.append(string_to_spv("matmul_q4_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2))
- tasks.append(string_to_spv("matmul_q4_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q4_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2))
- tasks.append(string_to_spv("matmul_q5_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q5_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2))
- tasks.append(string_to_spv("matmul_q5_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q5_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2))
- tasks.append(string_to_spv("matmul_q8_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q8_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2))
- tasks.append(string_to_spv("matmul_q2_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q2_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2))
- tasks.append(string_to_spv("matmul_q3_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q3_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2))
- tasks.append(string_to_spv("matmul_q4_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q4_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2))
- tasks.append(string_to_spv("matmul_q5_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q5_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- stream.clear()
- stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2))
- tasks.append(string_to_spv("matmul_q6_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_q6_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # MUL_MAT_ID
- # stream.clear()
- # stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # tasks.append(string_to_spv("matmul_id_f16", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_f16_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
-
- # tasks.append(string_to_spv("matmul_id_f16_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_f16_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q4_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q4_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q4_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q4_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q5_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q5_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q5_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q5_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q8_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q8_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q2_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q2_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q3_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q3_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q4_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q4_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q5_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q5_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # stream.clear()
- # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2))
- # tasks.append(string_to_spv("matmul_id_q6_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- # tasks.append(string_to_spv("matmul_id_q6_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- # Shaders where precision is needed, so no fp16 version
-
- # mul mat vec
- for i in range(0, VK_NUM_TYPES):
- stream.clear()
- stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
-
- if i == GGML_TYPE_F16:
- stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body))
- elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body))
- elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, mul_mat_vec_layout, shader_q4_1_dequant_func, mul_mat_vec_body))
- elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, mul_mat_vec_layout, shader_q5_0_dequant_func, mul_mat_vec_body))
- elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, mul_mat_vec_layout, shader_q5_1_dequant_func, mul_mat_vec_body))
- elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, mul_mat_vec_layout, shader_q8_0_dequant_func, mul_mat_vec_body))
- elif i == GGML_TYPE_Q2_K:
- stream.extend((shader_q2_K_defines, mul_mat_vec_layout, mul_mat_vec_q2_K_body))
- elif i == GGML_TYPE_Q3_K:
- stream.extend((shader_q3_K_defines, mul_mat_vec_layout, mul_mat_vec_q3_K_body))
- elif i == GGML_TYPE_Q4_K:
- stream.extend((shader_q4_K_defines, mul_mat_vec_layout, mul_mat_vec_q4_K_body))
- elif i == GGML_TYPE_Q5_K:
- stream.extend((shader_q5_K_defines, mul_mat_vec_layout, mul_mat_vec_q5_K_body))
- elif i == GGML_TYPE_Q6_K:
- stream.extend((shader_q6_K_defines, mul_mat_vec_layout, mul_mat_vec_q6_K_body))
- else:
- continue
-
- tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
- tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f16_f32", "".join(stream), {"B_TYPE": "float16_t", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
-
- # tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
-
- # Dequant shaders
- for i in range(0, VK_NUM_TYPES):
- stream.clear()
-
- stream.extend((dequant_head, shader_int8_ext, shader_f32))
-
- if i == GGML_TYPE_F32:
- stream.append(dequant_f32_body)
- elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, dequant_q4_0_body))
- elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, dequant_q4_1_body))
- elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, dequant_q5_0_body))
- elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, dequant_q5_1_body))
- elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, dequant_q8_0_body))
- elif i == GGML_TYPE_Q2_K:
- stream.extend((shader_q2_K_defines, dequant_q2_K_body))
- elif i == GGML_TYPE_Q3_K:
- stream.extend((shader_q3_K_defines, dequant_q3_K_body))
- elif i == GGML_TYPE_Q4_K:
- stream.extend((shader_q4_K_defines, dequant_q4_K_body))
- elif i == GGML_TYPE_Q5_K:
- stream.extend((shader_q5_K_defines, dequant_q5_K_body))
- elif i == GGML_TYPE_Q6_K:
- stream.extend((shader_q6_K_defines, dequant_q6_K_body))
- else:
- continue
-
- tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
-
- # get_rows
- for i in range(0, VK_NUM_TYPES):
- stream.clear()
- stream.extend((generic_binary_op_head, shader_int8_ext, shader_f32))
- optimization_workaround = False
-
- if i == GGML_TYPE_F32:
- stream.extend((shader_f32_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
- elif i == GGML_TYPE_F16:
- stream.extend((shader_f16_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
- optimization_workaround = True
- elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, generic_binary_op_layout, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body))
- elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, generic_binary_op_layout, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body))
- elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, generic_binary_op_layout, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body))
- elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, generic_binary_op_layout, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body))
- elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, generic_binary_op_layout, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body))
- else:
- continue
-
- if optimization_workaround:
- tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
- else:
- tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t"}))
- tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float"}))
-
- tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
-
- # Norms
- tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-
- tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
- tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_combined}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
-
- tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
- tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
- tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
- tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
- tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
- tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_combined}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
-
- tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-
- tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
-
- tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"}))
-
- tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
-
- tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
-
- tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
-
- # Helper to decorate tasks with semaphore acquisition.
- async def withSemaphore(sem, task):
- async with sem:
- return await task
-
- # Run tasks concurrently guarded by a concurrency limit.
- sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
- await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
-
- with open("ggml-vulkan-shaders.hpp", "w") as f:
- f.write("#include <cstdint>\n\n")
- for name, path in sorted(shader_fnames):
-
- with open(path, "rb") as spv:
- counter = 0
- newline_counter = 0
- f.write(f"unsigned char {name}_data[] = {{\n")
- for val in spv.read():
- f.write(f"0x{val:02x},")
- newline_counter += 1
- counter += 1
- if newline_counter >= 12:
- newline_counter = 0
- f.write("\n")
- f.write("\n};\n")
- f.write(f"const uint64_t {name}_len = {counter};\n\n")
- os.remove(path)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
-
- parser.add_argument("--glslc", help="Path to glslc")
- parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-
- args = parser.parse_args()
-
- logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
-
- if args.glslc:
- GLSLC = args.glslc
-
- asyncio.run(main())
\ No newline at end of file
#include "generic_binary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
}
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
--- /dev/null
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+ const int dim = p.param3;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
+ const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
+ const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
+ const uint i2_offset = i2*p.ne21*p.ne20;
+ const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
+
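+    // Along the concat dimension, src1 starts where src0 ends, so subtract
+    // src0's extent in that dimension when indexing into src1.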
+ uint o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
+
+ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+ const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
+ const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
+
+ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+ data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
+#else
+ data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
+#endif
+}
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
#else
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
+ data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
#endif
}
#include "generic_binary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
}
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
- const uint i = gl_GlobalInvocationID.x;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
--- /dev/null
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const float GELU_QUICK_COEF = -1.702f;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
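+    // Quick GELU: approximate GELU(x) as x * sigmoid(1.702 * x).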
+ const float x = float(data_a[i]);
+ data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
+}
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint d_offset;
- float param1; float param2;
+ float param1; float param2; int param3;
} p;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
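+// Flatten the 3D dispatch into a linear element index: x covers one
+// 512-thread workgroup, y steps in units of 512 and z in units of
+// 512 * 512 = 262144, letting grids exceed a single dimension's limit.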
+uint get_idx() {
+ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
uint src0_idx(uint idx) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
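+// Same flattened 3D index as in generic_binary_head.comp.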
+uint get_idx() {
+ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+}
+
uint src0_idx(uint idx) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
--- /dev/null
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) buffer D {D_TYPE data_d[];};  // not writeonly: the final pass scales data_d in place
+
+shared float tmp[BLOCK_SIZE];
+
+void main() {
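+    // One workgroup normalizes one group of p.KX contiguous elements:
+    // subtract the group mean, then scale by 1 / sqrt(variance + eps).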
+ const uint group_size = p.KX;
+ const float eps = p.param1;
+
+ const uint tid = gl_LocalInvocationID.x;
+    const uint start = gl_WorkGroupID.x * group_size + tid;
+    const uint end = (gl_WorkGroupID.x + 1) * group_size;
+
+ tmp[tid] = 0.0f;
+
+ // Calculate mean
+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+ tmp[tid] += float(data_a[col]);
+ }
+
+    // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+
+ const float mean = tmp[0] / group_size;
+ barrier();
+ tmp[tid] = 0.0f;
+
+ // Calculate variance
+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+ const float xi = float(data_a[col]) - mean;
+ data_d[col] = D_TYPE(xi);
+ tmp[tid] += xi * xi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+
+ const float variance = tmp[0] / group_size;
+ const float scale = inversesqrt(variance + eps);
+
+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
+ data_d[col] *= D_TYPE(scale);
+ }
+}
--- /dev/null
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint batch_offset; uint offset_delta;
+ uint IC;
+ uint IW; uint IH;
+ uint OW; uint OH;
+ uint KW; uint KH;
+ uint pelements;
+ uint CHW;
+ int s0; int s1;
+ int p0; int p1;
+ int d0; int d1;
+} p;
+
+#include "types.comp"
+
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.x;
+ if (i >= p.pelements) {
+ return;
+ }
+
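+    // Decompose the flat patch index: kx/ky select the kernel element and ix
+    // the output column; the y/z dispatch dimensions carry the output row,
+    // batch and input channel.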
+ const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+ const uint kx = i / ksize;
+ const uint kd = kx * ksize;
+ const uint ky = (i - kd) / p.OW;
+ const uint ix = i % p.OW;
+
+ const uint oh = gl_GlobalInvocationID.y;
+ const uint batch = gl_GlobalInvocationID.z / p.IC;
+ const uint ic = gl_GlobalInvocationID.z % p.IC;
+
+ const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
+ const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
+
+ const uint offset_dst =
+ ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
+ (ic * (p.KW * p.KH) + ky * p.KW + kx);
+
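+    // iiw/iih are unsigned, so negative positions wrap to large values and
+    // are caught by the >= bounds checks below.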
+ if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
+ data_d[offset_dst] = D_TYPE(0.0f);
+ } else {
+ const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
+ data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
+ }
+}
--- /dev/null
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
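+    // LeakyReLU: pass positives through, scale negatives by slope p.param1.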
+ const float val = float(data_a[i]);
+ data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
+}
#include "generic_binary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
}
shared vec2 sum[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.x;
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = vec2(0.0f, 0.0f);
--- /dev/null
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
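+    // Indices are computed against the padded destination shape; anything
+    // outside the source extent is written as zero.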
+ const uint i3 = idx / (p.ne12*p.ne11*p.ne10);
+ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
+ const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10);
+ const uint i2_offset = i2*p.ne11*p.ne10;
+ const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
+
+ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
+ const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
+
+ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
+
+ data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
+}
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
- const uint i = gl_GlobalInvocationID.x;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.x;
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1));
}
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
- const uint i = gl_GlobalInvocationID.x;
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
void main() {
const uint tid = gl_LocalInvocationID.x;
- const uint rowx = gl_WorkGroupID.x;
+ const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
float slope = 1.0f;
#include "generic_unary_head.comp"
void main() {
- if (gl_GlobalInvocationID.x >= p.ne) {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
return;
}
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
- data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
}
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
- const uint row = gl_WorkGroupID.x;
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
tmp[col] = FLOAT_TYPE(0.0f);
--- /dev/null
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ data_d[i] = D_TYPE(tanh(data_a[i]));
+}
--- /dev/null
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint nb1;
+ uint dim;
+ uint max_period;
+} p;
+
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.y;
+ const uint j = gl_GlobalInvocationID.x;
+ const uint d_offset = i * p.nb1;
+
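+    // Sinusoidal embedding: freq_j = max_period^(-j / half_dim); the first
+    // half of each row stores cos(t * freq_j), the second half sin(t * freq_j),
+    // and an odd dim gets one zero-padded element.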
+ if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
+ data_d[d_offset + p.dim] = 0.f;
+ }
+
+ const uint half_dim = p.dim / 2;
+ if (j >= half_dim) {
+ return;
+ }
+
+ const float timestep = float(data_a[i]);
+ const float freq = float(exp(-log(p.max_period) * j / half_dim));
+ const float arg = timestep * freq;
+ data_d[d_offset + j] = D_TYPE(cos(arg));
+ data_d[d_offset + j + half_dim] = D_TYPE(sin(arg));
+}
#define QUANT_K 1
#define QUANT_R 1
-#ifndef LOAD_VEC_A
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float
#elif LOAD_VEC_A == 4
#define A_TYPE vec4
#define QUANT_K 1
#define QUANT_R 1
-#ifndef LOAD_VEC_A
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float16_t
#elif LOAD_VEC_A == 4
#define A_TYPE f16vec4
--- /dev/null
+#version 450
+
+layout (push_constant) uniform parameter
+{
+ uint ne; uint d_offset;
+ uint nb00; uint nb01; uint nb02; uint nb03;
+ uint ne10; uint ne11; uint ne12; uint ne13;
+ float sf0; float sf1; float sf2; float sf3;
+} p;
+
+#include "types.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
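+    // Nearest-neighbor upscale: map each destination coordinate back to a
+    // source coordinate by dividing by the per-dimension scale factor.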
+ const uint i10 = idx % p.ne10;
+ const uint i11 = (idx / p.ne10) % p.ne11;
+ const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
+ const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
+
+ const uint i00 = uint(i10 / p.sf0);
+ const uint i01 = uint(i11 / p.sf1);
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+
+ data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
+}
for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ // For unaligned, load one at a time for f32/f16, or two at a time for quants
+ std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
+ // For aligned matmul loads
std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
tasks.push_back(std::async(std::launch::async, [=] {
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
+ string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
}));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+ }));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
}
void write_output_files() {