#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
// Matrix-vector multiplication parameters
-#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_TILE_K 256
+#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
+
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
+
+// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
+#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
+// Requires at least two (and a multiple of two) k-quant blocks per tile
+#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
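+// With WEBGPU_MUL_MAT_VEC_WG_SIZE=256 and 8 outputs per workgroup, each output is reduced by
+// 256/8 = 32 threads, and TILE_K=512 covers two 256-wide k-quant superblocks per tile.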
// Default workgroup size for the legacy matrix multiplication shader
#define WEBGPU_MUL_MAT_WG_SIZE 256
bool src_overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
- return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
+ return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
+ src_overlap == other.src_overlap;
}
};
std::vector<std::string> defines;
std::string variant = "mul_mat_vec";
- // src1 type (vector)
- switch (context.src1->type) {
- case GGML_TYPE_F32:
- defines.push_back("SRC1_INNER_TYPE=f32");
- variant += "_f32";
- break;
- case GGML_TYPE_F16:
- defines.push_back("SRC1_INNER_TYPE=f16");
- variant += "_f16";
- break;
- default:
- GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
- }
-
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
+ variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
+ variant += "_f16";
break;
default:
{
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
+ variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
}
}
+ // src1 type (vector)
+ switch (context.src1->type) {
+ case GGML_TYPE_F32:
+ defines.push_back("SRC1_INNER_TYPE=f32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("SRC1_INNER_TYPE=f16");
+ variant += "_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
+ }
+
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
- uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
- uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+ uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
+ uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
+
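+    // These range checks rely on ggml_type enum ordering: the legacy quants (Q4_0..Q8_1)
+    // precede the k-quants, which begin at GGML_TYPE_Q2_K.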
+ if (key.src0_type >= GGML_TYPE_Q2_K) {
+ tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
+ } else if (key.src0_type >= GGML_TYPE_Q4_0) {
+ tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
+ }
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
- .type = context.dst->type,
- .op = context.dst->op,
- .inplace = context.inplace,
- .overlap = context.overlap,
+ .type = context.dst->type,
+ .op = context.dst->op,
+ .inplace = context.inplace,
+ .overlap = context.overlap,
.src_overlap = context.src_overlap,
};
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
-#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
#include <condition_variable>
#include <cstdint>
#include <cstring>
-#include <iostream>
+#ifdef GGML_WEBGPU_GPU_PROFILE
+# include <iomanip>
+#endif
+#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
+# include <iostream>
+#endif
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
+#include <utility>
#include <vector>
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
#endif // GGML_WEBGPU_CPU_PROFILE
#ifdef GGML_WEBGPU_GPU_PROFILE
-# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
+# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
#endif
/* Constants */
-#define WEBGPU_NUM_PARAM_BUFS 48u
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
+#define WEBGPU_NUM_PARAM_BUFS 96u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
-#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
+#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
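+// With the values above this allows 96 / 32 = 3 batched submissions in flight per thread.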
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
-#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
-#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable
// default
wgpu::BufferUsage usage,
const char * label);
-struct webgpu_pool_bufs {
- wgpu::Buffer host_buf;
- wgpu::Buffer dev_buf;
-};
-
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
- std::vector<webgpu_pool_bufs> free;
+ std::vector<wgpu::Buffer> free;
// The pool must be synchronized because
// 1. The memset pool is shared globally by every ggml buffer,
size_t cur_pool_size;
size_t max_pool_size;
wgpu::Device device;
- wgpu::BufferUsage host_buf_usage;
wgpu::BufferUsage dev_buf_usage;
size_t buf_size;
bool should_grow;
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
- wgpu::BufferUsage host_buf_usage,
bool should_grow = false,
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
- this->max_pool_size = max_pool_size;
- this->cur_pool_size = num_bufs;
- this->device = device;
- this->host_buf_usage = host_buf_usage;
- this->dev_buf_usage = dev_buf_usage;
- this->buf_size = buf_size;
- this->should_grow = should_grow;
+ this->max_pool_size = max_pool_size;
+ this->cur_pool_size = num_bufs;
+ this->device = device;
+ this->dev_buf_usage = dev_buf_usage;
+ this->buf_size = buf_size;
+ this->should_grow = should_grow;
for (int i = 0; i < num_bufs; i++) {
- wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
- ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
- free.push_back({ host_buf, dev_buf });
+ free.push_back(dev_buf);
}
}
- webgpu_pool_bufs alloc_bufs() {
+ wgpu::Buffer alloc_bufs() {
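+        // Returns a free buffer if one is available, grows the pool (up to max_pool_size)
+        // when allowed, and otherwise blocks until a buffer is returned to the pool.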
std::unique_lock<std::mutex> lock(mutex);
if (!free.empty()) {
- webgpu_pool_bufs bufs = free.back();
+ wgpu::Buffer buf = free.back();
free.pop_back();
- return bufs;
+ return buf;
}
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
- wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
- ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
- if (!(host_buf && dev_buf)) {
+ if (!dev_buf) {
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
}
- return webgpu_pool_bufs{ host_buf, dev_buf };
+ return dev_buf;
}
cv.wait(lock, [this] { return !free.empty(); });
- webgpu_pool_bufs bufs = free.back();
+ wgpu::Buffer buf = free.back();
free.pop_back();
- return bufs;
+ return buf;
}
- void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
+ void free_bufs(std::vector<wgpu::Buffer> bufs) {
std::lock_guard<std::mutex> lock(mutex);
free.insert(free.end(), bufs.begin(), bufs.end());
cv.notify_all();
void cleanup() {
std::lock_guard<std::mutex> lock(mutex);
- for (auto & bufs : free) {
- if (bufs.host_buf) {
- bufs.host_buf.Destroy();
- }
- if (bufs.dev_buf) {
- bufs.dev_buf.Destroy();
+ for (auto & buf : free) {
+ if (buf) {
+ buf.Destroy();
}
}
free.clear();
#endif
struct webgpu_command {
- uint32_t num_kernels;
- wgpu::CommandBuffer commands;
- std::vector<webgpu_pool_bufs> params_bufs;
- std::optional<webgpu_pool_bufs> set_rows_error_bufs;
+ uint32_t num_kernels;
+ wgpu::CommandBuffer commands;
+ std::vector<wgpu::Buffer> params_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs timestamp_query_bufs;
std::string pipeline_name;
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
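+// Tracking state for one batched submission: the future signaled by OnSubmittedWorkDone,
+// plus (when GGML_WEBGPU_GPU_PROFILE is enabled) the map futures for its timestamp query buffers.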
+struct webgpu_submission {
+ wgpu::FutureWaitInfo submit_done;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ std::vector<wgpu::FutureWaitInfo> profile_futures;
+#endif
+};
+
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
// Points to global instances owned by ggml_backend_webgpu_reg_context
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
webgpu_buf_pool param_buf_pool;
- webgpu_buf_pool set_rows_error_buf_pool;
+ wgpu::Buffer set_rows_dev_error_buf;
+ wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
/** End WebGPU object initializations */
/** WebGPU Actions */
-static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
+
+static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
+ switch (status) {
+ case wgpu::WaitStatus::Success:
+ return true;
+ case wgpu::WaitStatus::TimedOut:
+ if (allow_timeout) {
+ return false;
+ }
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
+ return false;
+ case wgpu::WaitStatus::Error:
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+ return false;
+ default:
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
+ return false;
+ }
+}
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
futures.erase(std::remove_if(futures.begin(), futures.end(),
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
futures.end());
}
-// Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
- std::vector<wgpu::FutureWaitInfo> & futures,
- bool block = true) {
- // If we have too many in-flight submissions, wait on the oldest one first.
+static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
+ std::vector<wgpu::FutureWaitInfo> & futures,
+ bool block) {
if (futures.empty()) {
return;
}
+
uint64_t timeout_ms = block ? UINT64_MAX : 0;
- while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
- auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
- if (waitStatus == wgpu::WaitStatus::Error) {
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+ if (block) {
+ while (!futures.empty()) {
+ auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
+ ggml_backend_webgpu_erase_completed_futures(futures);
+ }
}
- if (futures[0].completed) {
- futures.erase(futures.begin());
+ } else {
+ auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
+ ggml_backend_webgpu_erase_completed_futures(futures);
}
}
+}
+#endif
- if (futures.empty()) {
+// Wait for the queue to finish processing all submitted work
+static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
+ std::vector<webgpu_submission> & subs,
+ bool block = true) {
+ // If we have too many in-flight submissions, wait on the oldest one first.
+ if (subs.empty()) {
+ return;
+ }
+ while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
+ auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
+#endif
+ subs.erase(subs.begin());
+ }
+ }
+
+ if (subs.empty()) {
return;
}
if (block) {
- while (!futures.empty()) {
- auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
- switch (waitStatus) {
- case wgpu::WaitStatus::Success:
- // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
- erase_completed(futures);
- break;
- case wgpu::WaitStatus::Error:
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
- break;
- default:
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
- break;
+ for (auto & sub : subs) {
+ while (!sub.submit_done.completed) {
+ auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
+ ggml_backend_webgpu_handle_wait_status(waitStatus);
}
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
+#endif
}
+ subs.clear();
} else {
- // Poll once and return
- auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
- switch (waitStatus) {
- case wgpu::WaitStatus::Success:
- // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
- erase_completed(futures);
- break;
- case wgpu::WaitStatus::TimedOut:
- break;
- case wgpu::WaitStatus::Error:
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
- break;
- default:
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
- break;
+ // Poll each submit future once and remove completed submissions.
+ for (auto sub = subs.begin(); sub != subs.end();) {
+ auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
+ ggml_backend_webgpu_handle_wait_status(waitStatus, true);
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
+ if (sub->submit_done.completed && sub->profile_futures.empty()) {
+#else
+ if (sub->submit_done.completed) {
+#endif
+ sub = subs.erase(sub);
+ } else {
+ ++sub;
+ }
}
}
}
}
#endif
-static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
- webgpu_global_context ctx,
- std::vector<webgpu_command> commands,
- webgpu_buf_pool & param_buf_pool,
- webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
+static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
+ std::vector<webgpu_command> & commands,
+ webgpu_buf_pool & param_buf_pool) {
std::vector<wgpu::CommandBuffer> command_buffers;
- std::vector<webgpu_pool_bufs> params_bufs;
- std::vector<webgpu_pool_bufs> set_rows_error_bufs;
+ std::vector<wgpu::Buffer> params_bufs;
+ webgpu_submission submission;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
#endif
for (const auto & command : commands) {
command_buffers.push_back(command.commands);
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
- if (command.set_rows_error_bufs) {
- set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
- }
}
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
- std::vector<wgpu::FutureWaitInfo> futures;
-
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
// Free the staged buffers
param_buf_pool.free_bufs(params_bufs);
});
- futures.push_back({ p_f });
-
- for (const auto & bufs : set_rows_error_bufs) {
- wgpu::Future f = bufs.host_buf.MapAsync(
- wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
- [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
- if (status != wgpu::MapAsyncStatus::Success) {
- GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
- } else {
- const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
- if (*error_data) {
- GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
- }
- // We can't unmap in here due to WebGPU reentrancy limitations.
- if (set_rows_error_buf_pool) {
- set_rows_error_buf_pool->free_bufs({ bufs });
- }
- }
- });
- futures.push_back({ f });
- }
+ submission.submit_done = { p_f };
#ifdef GGML_WEBGPU_GPU_PROFILE
for (const auto & command : commands) {
// WebGPU timestamps are in ns; convert to ms
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
ctx->shader_gpu_time_ms[label] += elapsed_ms;
- // We can't unmap in here due to WebGPU reentrancy limitations.
- ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
}
+ // We can't unmap in here due to WebGPU reentrancy limitations.
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
});
- futures.push_back({ f });
+ submission.profile_futures.push_back({ f });
}
#endif
- return futures;
+ return submission;
}
static webgpu_command ggml_backend_webgpu_build_multi(
const std::vector<webgpu_pipeline> & pipelines,
const std::vector<std::vector<uint32_t>> & params_list,
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
- const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
- const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
+ const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
GGML_ASSERT(pipelines.size() == params_list.size());
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
GGML_ASSERT(pipelines.size() == workgroups_list.size());
- std::vector<webgpu_pool_bufs> params_bufs_list;
- std::vector<wgpu::BindGroup> bind_groups;
+ std::vector<wgpu::Buffer> params_bufs_list;
+ std::vector<wgpu::BindGroup> bind_groups;
for (size_t i = 0; i < pipelines.size(); i++) {
- webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
-
- ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
- params_bufs.host_buf.GetSize());
- uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
- for (size_t j = 0; j < params_list[i].size(); j++) {
- _params[j] = params_list[i][j];
- }
- params_bufs.host_buf.Unmap();
+ wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
uint32_t params_binding_num = entries.size();
- entries.push_back({ .binding = params_binding_num,
- .buffer = params_bufs.dev_buf,
- .offset = 0,
- .size = params_bufs.dev_buf.GetSize() });
+ entries.push_back(
+ { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
}
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
- for (const auto & params_bufs : params_bufs_list) {
- encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
- }
-
- // If there are SET_ROWS operations in this submission, copy their error
- // buffers to the host.
- if (set_rows_error_bufs) {
- encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
- set_rows_error_bufs->host_buf.GetSize());
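+    // Upload parameters directly with queue.WriteBuffer; the implementation stages the copy
+    // internally, so the mapped host-side pool buffers are no longer needed.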
+ for (size_t i = 0; i < params_bufs_list.size(); i++) {
+ ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_command result = {};
result.commands = commands;
result.params_bufs = params_bufs_list;
- result.set_rows_error_bufs = set_rows_error_bufs;
result.num_kernels = pipelines.size();
#ifdef GGML_WEBGPU_GPU_PROFILE
result.timestamp_query_bufs = ts_bufs;
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
- uint32_t wg_y = 1,
- std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
+ uint32_t wg_y = 1) {
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
{
pipeline
},
- { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
+ { std::move(params) }, { std::move(bind_group_entries) },
+ { { wg_x, wg_y } });
}
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
webgpu_command command =
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
- auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
- ggml_backend_webgpu_wait(ctx, futures);
+ std::vector<webgpu_command> commands = { command };
+ std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
+ ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
std::cout << "\nggml_webgpu: gpu breakdown:\n";
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
- std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
+ << pct << "%)\n";
}
#endif
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
- std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
- if (decisions->i64_idx) {
- error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
- if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
- error_bufs->host_buf.Unmap();
- }
- }
-
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
};
if (decisions->i64_idx) {
- entries.push_back(
- { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
+ entries.push_back({ .binding = 3,
+ .buffer = ctx->set_rows_dev_error_buf,
+ .offset = 0,
+ .size = ctx->set_rows_dev_error_buf.GetSize() });
}
uint32_t threads;
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
- return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
- error_bufs);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
}
// Workgroup size is a common constant
use_fast = (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
+ // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q6_K:
use_fast = true;
break;
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+                    // we don't have a fast mat-vec path for these types yet, but we do have a (semi-)fast mat-mat path
+ use_fast = !is_vec;
+ break;
default:
break;
}
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
- auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
+ auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
- auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
+ auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Fast-path tiled/subgroup calculations
- uint32_t wg_m, wg_n;
+ uint32_t wg_m;
+ uint32_t wg_n;
if (decisions->use_subgroup_matrix) {
uint32_t wg_m_sg_tile =
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else { // legacy
- auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
- ggml_tensor * src0,
- ggml_tensor * src1,
- ggml_tensor * dst) {
- uint32_t ne = (uint32_t) ggml_nelements(dst);
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
- (uint32_t)src0->ne[dim]
+ (uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
- {
- .binding = 0,
- .buffer = ggml_webgpu_tensor_buf(src0),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
- .size = ggml_webgpu_tensor_binding_size(ctx, src0)
- },
- {
- .binding = 1,
- .buffer = ggml_webgpu_tensor_buf(src1),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
- .size = ggml_webgpu_tensor_binding_size(ctx, src1)
- },
- {
- .binding = 2,
- .buffer = ggml_webgpu_tensor_buf(dst),
- .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
- .size = ggml_webgpu_tensor_binding_size(ctx, dst)
- }
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
- webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
- auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
- uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+ webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
case GGML_OP_SOFT_MAX:
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
case GGML_OP_UNARY:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_CLAMP:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_FILL:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_LOG:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQR:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQRT:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SIN:
- return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_COS:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
case GGML_OP_ARGMAX:
return ggml_webgpu_argmax(ctx, src0, node);
case GGML_OP_ARGSORT:
- return ggml_webgpu_argsort(ctx, src0, node);
case GGML_OP_TOP_K:
// we reuse the same argsort implementation for top_k
return ggml_webgpu_argsort(ctx, src0, node);
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
- std::vector<webgpu_command> commands;
- std::vector<wgpu::FutureWaitInfo> futures;
- uint32_t num_batched_kernels = 0;
+ std::vector<webgpu_command> commands;
+ std::vector<webgpu_submission> subs;
+ uint32_t num_batched_kernels = 0;
+ bool contains_set_rows = false;
+
for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
+ contains_set_rows = true;
+ }
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
}
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
- num_batched_kernels = 0;
- std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
- ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
- futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
+ num_batched_kernels = 0;
+ subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
// Process events and check for completed submissions
ctx->global_ctx->instance.ProcessEvents();
- ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
+ ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
commands.clear();
}
}
if (!commands.empty()) {
- auto new_futures =
- ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
- futures.insert(futures.end(), new_futures.begin(), new_futures.end());
+ subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
+ commands.clear();
+ }
+
+    // If there are SET_ROWS operations in this graph, copy the device error buffer to the host and check it.
+ if (contains_set_rows) {
+ wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
+ encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
+ ctx->set_rows_host_error_buf.GetSize());
+ wgpu::CommandBuffer set_rows_commands = encoder.Finish();
+ ctx->global_ctx->queue.Submit(1, &set_rows_commands);
+ ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
+ ctx->set_rows_host_error_buf.GetSize());
+ const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
+ if (*error_data) {
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
+ }
+ ctx->set_rows_host_error_buf.Unmap();
}
- ggml_backend_webgpu_wait(ctx->global_ctx, futures);
+ ggml_backend_webgpu_wait(ctx->global_ctx, subs);
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
return GGML_STATUS_SUCCESS;
}
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
                                     wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, true);
- webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
- WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
-#endif
+#endif // VEC
#ifdef SCALAR
#define VEC_SIZE 1
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
-#endif
+#endif // SCALAR
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
-#endif
+#endif // INIT_SRC0_SHMEM_FLOAT
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
-#endif
+#endif // INIT_SRC1_SHMEM_FLOAT
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
}
}
}
-#endif
+#endif // INIT_SRC0_SHMEM_Q4_0
+
+#ifdef INIT_SRC0_SHMEM_Q4_1
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 10u; // 1 scale + 1 mean + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+ let d = src0[scale_idx];
+ let m = src0[scale_idx + 1u];
+
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 2u + block_offset + j];
+ let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_lo = f16(q_byte & 0xF) * d + m;
+ let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
+ shmem[shmem_idx + j * 2 + k] = q_lo;
+ shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
+ }
+ }
+ }
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q4_1
+
+#ifdef INIT_SRC0_SHMEM_Q5_0
+// 32 weights per block; the low 4 bits of each weight are packed into qs: 32 * 4 = 128 bits = 8 f16s per block
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+// tile_k is defined as 32u, so blocks_k ends up being 1 always
+override BLOCKS_K = TILE_K / BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let d = src0[scale_idx];
+ let qh0 = src0[scale_idx + 1u];
+ let qh1 = src0[scale_idx + 2u];
+ let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+ for (var j = 0u; j < 2; j++) {
+ let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+ let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+ let j_adjusted = j + (block_offset / 2u);
+
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+
+ let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+ let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+ let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+ let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+ shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
+ shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
+ }
+ }
+ }
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q5_0
+
+#ifdef INIT_SRC0_SHMEM_Q5_1
+// 32 weights per block; the low 4 bits of each weight are packed into qs: 32 * 4 = 128 bits = 8 f16s per block
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+// tile_k is defined as 32u, so blocks_k ends up being 1 always
+override BLOCKS_K = TILE_K / BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 12u; // 1 scale + 1 mean + 2 qh + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let d = src0[scale_idx];
+ let m = src0[scale_idx + 1u];
+ let qh0 = src0[scale_idx + 2u];
+ let qh1 = src0[scale_idx + 3u];
+ let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+ for (var j = 0u; j < 2; j++) {
+
+ let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+ let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+ let j_adjusted = j + (block_offset / 2u);
+
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+
+ let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+ let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
+ let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+ let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
+
+ shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
+ shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
+ }
+ }
+ }
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q5_1
+
+#ifdef INIT_SRC0_SHMEM_Q8_0
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 17u; // 1 scale + 16 f16s of packed 8-bit weights
+const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+ let d = src0[scale_idx];
+
+ for (var j = 0u; j < F16_PER_THREAD; j+=2) {
+ let q_0 = src0[scale_idx + 1u + block_offset + j];
+ let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+
+ let q_val = f16(q_byte) * d;
+ shmem[shmem_idx + j * 2 + k] = q_val;
+ }
+ }
+ }
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q8_0
+
+#ifdef INIT_SRC0_SHMEM_Q8_1
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 16 f16s of packed 8-bit weights
+const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+ let d = src0[scale_idx];
+ let m = src0[scale_idx + 1u];
+
+ for (var j = 0u; j < F16_PER_THREAD; j+=2) {
+ let q_0 = src0[scale_idx + 2u + block_offset + j];
+ let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+
+ let q_val = f16(q_byte) * d + m;
+ shmem[shmem_idx + j * 2 + k] = q_val;
+ }
+ }
+ }
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q8_1
+
+#ifdef INIT_SRC0_SHMEM_Q2_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 42u;
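+// block_q2_K viewed as 42 f16 slots: scales (16 bytes) at 0..7, qs (64 bytes) at 8..39, d at 40, dmin at 41.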
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ // Use standard thread layout instead of lane/row_group
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+
+ if (global_m >= params.m || global_k >= params.k) {
+ shmem[elem_idx] = f16(0.0);
+ continue;
+ }
+
+ let block_k = global_k / BLOCK_SIZE;
+ let k_in_block = global_k % BLOCK_SIZE;
+
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let d = src0[scale_idx + 40u];
+ let dmin = src0[scale_idx + 41u];
+
+ // Decode the element at position k_in_block
+ let block_of_32 = k_in_block / 32u;
+ let pos_in_32 = k_in_block % 32u;
+
+ let q_b_idx = (block_of_32 / 4u) * 32u;
+ let shift = (block_of_32 % 4u) * 2u;
+ let k = (pos_in_32 / 16u) * 16u;
+ let l = pos_in_32 % 16u;
+
+ let is = k_in_block / 16u;
+
+ let sc_0 = src0[scale_idx + 2u * (is / 4u)];
+ let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
+ let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
+ let sc = get_byte(sc_packed, is % 4u);
+
+ let dl = d * f16(sc & 0xFu);
+ let ml = dmin * f16(sc >> 4u);
+
+ let q_idx = q_b_idx + k + l;
+ let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
+ let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte = get_byte(q_packed, q_idx % 4u);
+ let qs_val = (q_byte >> shift) & 3u;
+
+ let q_val = f16(qs_val) * dl - ml;
+ shmem[elem_idx] = q_val;
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q2_K
+
+#ifdef INIT_SRC0_SHMEM_Q3_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 55u;
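+// block_q3_K viewed as 55 f16 slots: hmask (32 bytes) at 0..15, qs (64 bytes) at 16..47, scales (12 bytes) at 48..53, d at 54.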
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+
+ if (global_m >= params.m || global_k >= params.k) {
+ shmem[elem_idx] = f16(0.0);
+ continue;
+ }
+
+ let block_k = global_k / BLOCK_SIZE;
+ let k_in_block = global_k % BLOCK_SIZE;
+
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let d = src0[scale_idx + 54u];
+
+ // Load and unpack scales
+ let kmask1: u32 = 0x03030303u;
+ let kmask2: u32 = 0x0f0f0f0fu;
+
+ var scale_vals: array<u32, 4>;
+ for (var i: u32 = 0u; i < 4u; i++) {
+ let scale_0 = src0[scale_idx + 48u + (2u*i)];
+ let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
+ scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+ }
+
+ var tmp: u32 = scale_vals[2];
+ scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
+ scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
+ scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
+ scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
+
+ // Load hmask and qs arrays
+ var hmask_vals: array<u32, 8>;
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let hmask_0 = src0[scale_idx + (2u*i)];
+ let hmask_1 = src0[scale_idx + (2u*i) + 1u];
+ hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
+ }
+
+ var qs_vals: array<u32, 16>;
+ for (var i: u32 = 0u; i < 16u; i++) {
+ let qs_0 = src0[scale_idx + 16u + (2u*i)];
+ let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
+ qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
+ }
+
+ let half = k_in_block / 128u; // 0 or 1
+ let pos_in_half = k_in_block % 128u; // 0-127
+ let shift_group = pos_in_half / 32u; // 0-3
+ let pos_in_32 = pos_in_half % 32u; // 0-31
+ let k_group = pos_in_32 / 16u; // 0 or 1
+ let l = pos_in_32 % 16u; // 0-15
+
+ let q_b_idx = half * 32u; // 0 or 32
+ let shift = shift_group * 2u; // 0, 2, 4, 6
+ let k = k_group * 16u; // 0 or 16
+ let is = k_in_block / 16u; // 0-15
+
+ // m increments every 32 elements across entire 256 element block
+ let m_shift = k_in_block / 32u; // 0-7
+ let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
+
+ let sc = get_byte(scale_vals[is / 4u], is % 4u);
+ let dl = d * (f16(sc) - 32.0);
+
+ let q_idx = q_b_idx + k + l;
+ let hm_idx = k + l;
+
+ let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
+ let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
+
+ let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+ let qs_val = (q_byte >> shift) & 3u;
+
+ let q_val = (f16(qs_val) - f16(hm)) * dl;
+ shmem[elem_idx] = q_val;
+ }
+}
+
+#endif // INIT_SRC0_SHMEM_Q3_K
+
+#ifdef INIT_SRC0_SHMEM_Q4_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 72u;
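+// block_q4_K viewed as 72 f16 slots: d at 0, dmin at 1, scales (12 bytes) at 2..7, qs (128 bytes) at 8..71.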
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+
+ if (global_m >= params.m || global_k >= params.k) {
+ shmem[elem_idx] = f16(0.0);
+ continue;
+ }
+
+ let block_k = global_k / BLOCK_SIZE;
+ let k_in_block = global_k % BLOCK_SIZE;
+
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let d = src0[scale_idx];
+ let dmin = src0[scale_idx + 1u];
+
+ // Load packed scales
+ var scale_vals: array<u32, 3>;
+ for (var i: u32 = 0u; i < 3u; i++) {
+ let scale_0 = src0[scale_idx + 2u + (2u*i)];
+ let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
+ scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+ }
+
+ // Map k_in_block to loop structure:
+ // Outer loop over 64-element groups (alternating q_b_idx)
+ // Inner loop over 2 shifts per group
+ let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
+ let pos_in_64 = k_in_block % 64u; // 0-63
+ let shift_group = pos_in_64 / 32u; // 0 or 1
+ let l = pos_in_64 % 32u; // 0-31
+
+ let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
+ let shift = shift_group * 4u; // 0 or 4
+ let is = k_in_block / 32u; // 0-7
+
+ var sc: u32;
+ var mn: u32;
+
+ if (is < 4u) {
+ let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
+ let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+ sc = sc_byte & 63u;
+ mn = min_byte & 63u;
+ } else {
+ let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
+ let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
+ let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+
+ sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
+ mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
+ }
+
+ let dl = d * f16(sc);
+ let ml = dmin * f16(mn);
+
+ let q_idx = q_b_idx + l;
+ let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
+ let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+ let q_byte = get_byte(q_packed, q_idx % 4u);
+ let qs_val = (q_byte >> shift) & 0xFu;
+
+ let q_val = f16(qs_val) * dl - ml;
+ shmem[elem_idx] = q_val;
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q4_K
+
+#ifdef INIT_SRC0_SHMEM_Q5_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 88u;
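+// block_q5_K viewed as 88 f16 slots: d at 0, dmin at 1, scales (12 bytes) at 2..7, qh (32 bytes) at 8..23, qs (128 bytes) at 24..87.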
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+
+ if (global_m >= params.m || global_k >= params.k) {
+ shmem[elem_idx] = f16(0.0);
+ continue;
+ }
+
+ let block_k = global_k / BLOCK_SIZE;
+ let k_in_block = global_k % BLOCK_SIZE;
+
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let d = src0[scale_idx];
+ let dmin = src0[scale_idx + 1u];
+
+ // Load packed scales
+ var scale_vals: array<u32, 3>;
+ for (var i: u32 = 0u; i < 3u; i++) {
+ let scale_0 = src0[scale_idx + 2u + (2u*i)];
+ let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
+ scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+ }
+
+ // The original loop processes elements in groups of 64
+ // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
+ // But u increments EVERY 32 elements (after each l loop)
+ let group_of_64 = k_in_block / 64u; // 0-3
+ let pos_in_64 = k_in_block % 64u; // 0-63
+ let shift_group = pos_in_64 / 32u; // 0 or 1
+ let l = pos_in_64 % 32u; // 0-31
+
+ let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
+ let shift = shift_group * 4u; // 0 or 4
+ let is = k_in_block / 32u; // 0-7
+
+ // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
+ let u_shift = k_in_block / 32u; // 0-7
+ let u: u32 = 1u << u_shift;
+
+ var sc: u32;
+ var mn: u32;
+
+ if (is < 4u) {
+ let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
+ let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+ sc = sc_byte & 63u;
+ mn = min_byte & 63u;
+ } else {
+ let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
+ let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
+ let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+
+ sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
+ mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
+ }
+
+ let dl = d * f16(sc);
+ let ml = dmin * f16(mn);
+
+ let q_idx = q_b_idx + l;
+ let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
+ let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+ let q_byte = get_byte(q_packed, q_idx % 4u);
+
+ let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
+ let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
+ let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
+
+ let qh_byte = get_byte(qh_packed, l % 4u);
+
+ let qs_val = (q_byte >> shift) & 0xFu;
+ let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+
+ let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
+ shmem[elem_idx] = q_val;
+ }
+}
+
+#endif // INIT_SRC0_SHMEM_Q5_K
+
+#ifdef INIT_SRC0_SHMEM_Q6_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
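+// block_q6_K viewed as 105 f16 slots: ql (128 bytes) at 0..63, qh (64 bytes) at 64..95, scales (16 bytes) at 96..103, d at 104.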
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+
+ if (global_m >= params.m || global_k >= params.k) {
+ shmem[elem_idx] = f16(0.0);
+ continue;
+ }
+
+ let block_k = global_k / BLOCK_SIZE;
+ let k_in_block = global_k % BLOCK_SIZE;
+
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+
+ let half = k_in_block / 128u;
+ let pos_in_half = k_in_block % 128u;
+ let quarter = pos_in_half / 32u;
+ let l = pos_in_half % 32u;
+
+ let ql_b_idx = half * 64u;
+ let qh_b_idx = half * 32u;
+ let sc_b_idx = half * 8u;
+
+ // Load only ql13 word needed
+ let ql13_flat = ql_b_idx + l;
+ let ql13_word = ql13_flat / 4u;
+ let ql13 = bitcast<u32>(vec2(
+ src0[scale_idx + 2u * ql13_word],
+ src0[scale_idx + 2u * ql13_word + 1u]
+ ));
+ let ql13_b = get_byte(ql13, ql13_flat % 4u);
+
+ // Load only ql24 word needed
+ let ql24_flat = ql_b_idx + l + 32u;
+ let ql24_word = ql24_flat / 4u;
+ let ql24 = bitcast<u32>(vec2(
+ src0[scale_idx + 2u * ql24_word],
+ src0[scale_idx + 2u * ql24_word + 1u]
+ ));
+ let ql24_b = get_byte(ql24, ql24_flat % 4u);
+
+ // Load only qh word needed
+ let qh_flat = qh_b_idx + l;
+ let qh_word = qh_flat / 4u;
+ let qh = bitcast<u32>(vec2(
+ src0[scale_idx + 64u + 2u * qh_word],
+ src0[scale_idx + 64u + 2u * qh_word + 1u]
+ ));
+ let qh_b = get_byte(qh, qh_flat % 4u);
+
+ let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
+ let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
+ let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
+ let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
+
+ // Load only the scale word needed
+ let is = l / 16u;
+ let sc_idx = sc_b_idx + is + quarter * 2u;
+ let sc_word = sc_idx / 4u;
+ let sc = bitcast<u32>(vec2(
+ src0[scale_idx + 96u + 2u * sc_word],
+ src0[scale_idx + 96u + 2u * sc_word + 1u]
+ ));
+ let sc_val = get_byte_i32(sc, sc_idx % 4u);
+
+ let d = src0[scale_idx + 104u];
+
+ var q_val: f16;
+ if (quarter == 0u) {
+ q_val = q1;
+ } else if (quarter == 1u) {
+ q_val = q2;
+ } else if (quarter == 2u) {
+ q_val = q3;
+ } else {
+ q_val = q4;
+ }
+
+ shmem[elem_idx] = d * f16(sc_val) * q_val;
+ }
+}
+#endif // INIT_SRC0_SHMEM_Q6_K
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
-
enable f16;
#include "common_decls.tmpl"
}
#endif
+#ifdef MUL_ACC_Q4_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 10u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+ let d = f32(src0[scale_idx]);
+ let m = f32(src0[scale_idx + 1u]);
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 2u + block_offset + j];
+ let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+ let q_lo = f32(q_byte & 0xF) * d + m;
+ local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+ local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+ }
+ }
+ }
+ return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q5_0
+
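+// Q5_0 block as 11 f16 values: scale d, 4 bytes of per-weight high bits (qh),
+// then 16 bytes of low nibbles. Dequantization: w = d * (q - 16), q a 5-bit value.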
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 11u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each qs byte b packs the low nibbles of weights b and b + 16; the fifth bit of each comes from qh
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+ let d = f32(src0[scale_idx]);
+ let qh0 = src0[scale_idx + 1u];
+ let qh1 = src0[scale_idx + 2u];
+ let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+ for (var j = 0u; j < 2; j++) {
+ let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+ let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+ let j_adjusted = j + (block_offset / 2u);
+
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+
+ let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+ let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+ let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+ let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+ local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+ local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+ }
+
+ }
+ }
+ return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q5_1
+
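+// Q5_1 block as 12 f16 values: scale d, min m, 4 bytes of per-weight high bits
+// (qh), then 16 bytes of low nibbles. Dequantization: w = d * q + m, q a 5-bit value.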
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 12u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each qs byte b packs the low nibbles of weights b and b + 16; the fifth bit of each comes from qh
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+ let d = f32(src0[scale_idx]);
+ let m = src0[scale_idx + 1u];
+ let qh0 = src0[scale_idx + 2u];
+ let qh1 = src0[scale_idx + 3u];
+ let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+ for (var j = 0u; j < 2; j++) {
+ let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+ let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+ let j_adjusted = j + (block_offset / 2u);
+
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+
+ let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+ let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
+ let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+ let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
+
+ local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+ local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+ }
+
+ }
+ }
+ return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q8_0
+
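+// Q8_0 block as 17 f16 values: scale d followed by 32 signed 8-bit weights.
+// Dequantization: w = d * q.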
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 17u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 packs two consecutive signed 8-bit weights
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+ let d = f32(src0[scale_idx]);
+
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 1 + block_offset + j];
+ let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f32(q_byte) * d;
+ local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+ }
+ }
+ }
+ return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q8_1
+
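+// Q8_1 block as 18 f16 values: two f16 header values (read here as scale d and
+// offset m) followed by 32 signed 8-bit weights; dequantized here as w = d * q + m.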
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 18u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 packs two consecutive signed 8-bit weights
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+ let d = f32(src0[scale_idx]);
+ let m = src0[scale_idx + 1u];
+
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 2u + block_offset + j];
+ let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f32(q_byte) * d + f32(m);
+ local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+ }
+ }
+ }
+ return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q6_K
+
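+// Q6_K super-block as 105 f16 values (210 bytes): ql[128] low 4 bits, qh[64]
+// upper 2 bits, 16 signed 8-bit scales, then the f16 super-block scale d.
+// Dequantization: w = d * sc * (q - 32). tid = tig / 2 picks a lane's slice of
+// 16 weights inside a super-block; ix = tig % 2 interleaves lanes over the
+// super-blocks of the tile.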
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
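+// Read the 32-bit word containing byte_offset from the f16-typed src0 view,
+// aligning the offset down to a 4-byte boundary.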
+fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
+ let aligned = byte_offset & ~3u;
+ let idx = bbase + aligned / 2u;
+ return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
+}
+
+fn byte_of(v: u32, b: u32) -> u32 {
+ return (v >> (b * 8u)) & 0xFFu;
+}
+
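+// Extract byte b of v and sign-extend it as an 8-bit two's-complement value.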
+fn sbyte_of(v: u32, b: u32) -> i32 {
+ let raw = i32((v >> (b * 8u)) & 0xFFu);
+ return select(raw, raw - 256, raw >= 128);
+}
+
+fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ let tid = tig / 2u;
+ let ix = tig % 2u;
+ let ip = tid / 8u;
+ let il = tid % 8u;
+ let l0 = 4u * il;
+ let is = 8u * ip + l0 / 16u;
+
+ let y_offset = 128u * ip + l0;
+ let q_offset_l = 64u * ip + l0;
+ let q_offset_h = 32u * ip + l0;
+
+ let nb = tile_size / BLOCK_SIZE;
+ let k_block_start = k_outer / BLOCK_SIZE;
+
+ // Aligned scale byte position (is can be odd)
+ let sc_base_byte = 192u + (is & ~3u);
+ let sc_byte_pos = is & 3u;
+
+ var local_sum = 0.0;
+
+ for (var i = ix; i < nb; i += 2u) {
+ let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+
+ let d_raw = load_u32_at(bbase, 208u);
+ let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+
+ let ql1_u32 = load_u32_at(bbase, q_offset_l);
+ let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
+ let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
+ let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
+ let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+
+ let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
+ let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
+ let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
+ let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
+
+ var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+
+ for (var l = 0u; l < 4u; l++) {
+ let y_base = i * BLOCK_SIZE + y_offset + l;
+ let yl0 = f32(shared_vector[y_base]);
+ let yl1 = f32(shared_vector[y_base + 32u]);
+ let yl2 = f32(shared_vector[y_base + 64u]);
+ let yl3 = f32(shared_vector[y_base + 96u]);
+
+ let q1b = byte_of(ql1_u32, l);
+ let q2b = byte_of(ql2_u32, l);
+ let qhb = byte_of(qh_u32, l);
+
+ let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
+ let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
+            let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32);
+ let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
+
+ sums[0] += yl0 * dq0;
+ sums[1] += yl1 * dq1;
+ sums[2] += yl2 * dq2;
+ sums[3] += yl3 * dq3;
+ }
+
+ local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
+ sums[2] * f32(sc4) + sums[3] * f32(sc6));
+ }
+
+ return local_sum;
+}
+#endif
+
struct MulMatParams {
offset_src0: u32,
offset_src1: u32,
dst[dst_idx / VEC_SIZE] = store_val(group_base);
}
}
-