GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
+ GGML_OP_L2_NORM,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN,
+ GGML_OP_RWKV_WKV7,
GGML_OP_UNARY,
int n_groups,
float eps);
+ // l2 normalize along rows
+ // used in rwkv v7
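+ // dst = a / max(||a||, eps), applied row-wise (cf. torch.nn.functional.normalize)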
+ GGML_API struct ggml_tensor * ggml_l2_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps);
+
+ GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps);
+
// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_tensor * state,
float scale);
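+
+ // RWKV v7 (wkv7) recurrence, applied per head for each token:
+ //   sa     = sum_j a[j]*state[j]
+ //   state  = state*w + v*k + sa*b
+ //   output = sum_j state[j]*r[j]
+ // the result packs the per-token outputs followed by the updated per-sequence states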
+ GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+ struct ggml_context * ctx,
+ struct ggml_tensor * r,
+ struct ggml_tensor * w,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * state);
+
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
}
}
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(eps >= 0.0f);
+
+ // TODO: optimize
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float sum = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ sum += (ggml_float)(x[i00] * x[i00]);
+ }
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ memcpy(y, x, ne00 * sizeof(float));
+
+ const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_l2_norm(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_l2_norm_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_mul_mat
static void ggml_compute_forward_mul_mat_one_chunk(
}
}
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ const int64_t T = dst->src[1]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t HEADS = dst->src[1]->ne[1];
+ const int64_t n_seqs = dst->src[6]->ne[1];
+ const int64_t head_size = C / HEADS;
+
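+ // dst packs the per-token outputs (C * T floats) followed by the updated states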
+ float * dst_data = (float *) dst->data;
+ float * state = ((float *) dst->data) + C * T;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ if (ith >= HEADS) {
+ return;
+ }
+
+ const int h_start = (HEADS * ith) / nth;
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+ (HEADS * (ith + 1)) / nth : HEADS;
+
+ float * r = (float *) dst->src[0]->data;
+ float * w = (float *) dst->src[1]->data;
+ float * k = (float *) dst->src[2]->data;
+ float * v = (float *) dst->src[3]->data;
+ float * a = (float *) dst->src[4]->data;
+ float * b = (float *) dst->src[5]->data;
+
+ int64_t t_stride = HEADS * head_size; // same as C
+
+ int64_t h_stride = C / HEADS;
+ GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+ int64_t h_stride_2d = head_size * head_size;
+
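+ // tokens are processed sequentially; threads split the heads [h_start, h_end) among themselves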
+ #if defined(GGML_SIMD)
+ for (int64_t t = 0; t < T; t++) {
+ int64_t t_offset = t * t_stride;
+ int64_t state_offset = head_size * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+ for (int64_t h = h_start; h < h_end; h++) {
+ int64_t h_offset = h * h_stride;
+ int64_t t_h_offset = t_offset + h_offset;
+ int64_t h_2d_offset = h * h_stride_2d;
+
+ for (int64_t ii = 0; ii < head_size; ii++) {
+ int64_t t_h_i_offset = t_h_offset + ii;
+ int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+ GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+ float sa = 0;
+ {
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+ for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+ ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+ ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+ sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+ }
+ }
+ GGML_F32_VEC_REDUCE(sa, sum);
+ }
+
+ GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+ int64_t j = 0;
+ GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+ for (; j < head_size; j += GGML_F32_STEP) {
+ for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+ int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+ GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+ GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+ GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+ GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+ k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+ GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+ // kv + s * decay + sa * b
+ state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+ state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+ GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+ result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+ }
+ }
+ GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+ // scalar tail; head_size is expected to be a multiple of GGML_F32_STEP, so this loop normally does nothing
+ for (; j < head_size; j++) {
+ int64_t t_h_j_offset = t_h_offset + j;
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float r_val = r[t_h_j_offset];
+ float w_val = w[t_h_j_offset];
+ float k_val = k[t_h_j_offset];
+ float b_val = b[t_h_j_offset];
+ float kv_val = v[t_h_i_offset] * k_val;
+
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+ dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+ }
+ }
+ }
+ }
+ #else
+ for (int64_t t = 0; t < T; t++) {
+ int64_t t_offset = t * t_stride;
+ int64_t state_offset = head_size * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+ for (int64_t h = h_start; h < h_end; h++) {
+ int64_t h_offset = h * h_stride;
+ int64_t t_h_offset = t_offset + h_offset;
+ int64_t h_2d_offset = h * h_stride_2d;
+
+ for (int64_t i = 0; i < head_size; i++) {
+ int64_t t_h_i_offset = t_h_offset + i;
+ int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+ float v_val = v[t_h_i_offset];
+
+ float sa = 0, result = 0;
+ for (int64_t j = 0; j < head_size; j++) {
+ sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+ }
+
+ for (int64_t j = 0; j < head_size; j++) {
+ int64_t t_h_j_offset = t_h_offset + j;
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float r_val = r[t_h_j_offset];
+ float w_val = w[t_h_j_offset];
+ float k_val = k[t_h_j_offset];
+ float b_val = b[t_h_j_offset];
+ float kv_val = v_val * k_val;
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+ result += state_cur[h_2d_i_j_offset] * r_val;
+ }
+ dst_data[t_h_i_offset] = result;
+ }
+ }
+ }
+ #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32(
{
ggml_compute_forward_group_norm(params, tensor);
} break;
+ case GGML_OP_L2_NORM:
+ {
+ ggml_compute_forward_l2_norm(params, tensor);
+ } break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor);
{
ggml_compute_forward_gla(params, tensor);
} break;
+ case GGML_OP_RWKV_WKV7:
+ {
+ ggml_compute_forward_rwkv_wkv7(params, tensor);
+ } break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
+ case GGML_OP_RWKV_WKV6:
+ case GGML_OP_GATED_LINEAR_ATTN:
+ case GGML_OP_RWKV_WKV7:
{
n_tasks = n_threads;
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
- case GGML_OP_RWKV_WKV6:
- case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml.h"
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
+ case GGML_OP_L2_NORM:
+ ggml_cuda_op_l2_norm(ctx, dst);
+ break;
case GGML_OP_CONCAT:
ggml_cuda_op_concat(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
+ case GGML_OP_RWKV_WKV7:
+ ggml_cuda_op_rwkv_wkv7(ctx, dst);
+ break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
+ case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
}
}
+
+template <int block_size>
+static __global__ void l2_norm_f32(
+ const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+ const int64_t stride_sample, const float eps) {
+ const int nrows = gridDim.x;
+ const int nchannels = gridDim.y;
+
+ const int row = blockIdx.x;
+ const int channel = blockIdx.y;
+ const int sample = blockIdx.z;
+ const int tid = threadIdx.x;
+
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ tmp = warp_reduce_sum(tmp);
+ if constexpr (block_size > WARP_SIZE) {
+ static_assert(block_size == 1024, "unexpected block_size");
+ __shared__ float s_sum[32];
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+ const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[col] = scale * x[col];
+ }
+}
+
static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
}
}
+static void l2_norm_f32_cuda(
+ const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
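+ // one block per row: a single warp for short rows, a full 1024-thread block with a shared-memory reduction otherwise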
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ }
+}
+
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ const size_t ts0 = ggml_type_size(src0->type);
+ GGML_ASSERT(nb00 == ts0);
+ const int64_t s01 = nb01 / ts0;
+ const int64_t s02 = nb02 / ts0;
+ const int64_t s03 = nb03 / ts0;
+
+ l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
--- /dev/null
+#include "common.cuh"
+#include "wkv.cuh"
+
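+// CUDA kernels for the RWKV wkv6 and wkv7 recurrences: one thread block per (batch, head) pair, one thread per channel of the head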
+template <int block_size>
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = block_size;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
+ __syncthreads();
+ _tf[tid] = tf[head_i * head_size + tid];
+ __syncthreads();
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ __syncthreads();
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& k = (float4&)(_k[j]);
+ const float4& r = (float4&)(_r[j]);
+ const float4& tf = (float4&)(_tf[j]);
+ const float4& td = (float4&)(_td[j]);
+ float4& s = (float4&)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ y += r.x * (tf.x * kv.x + s.x);
+ y += r.y * (tf.y * kv.y + s.y);
+ y += r.z * (tf.z * kv.z + s.z);
+ y += r.w * (tf.w * kv.w + s.w);
+
+ s.x = s.x * td.x + kv.x;
+ s.y = s.y * td.y + kv.y;
+ s.z = s.z * td.z + kv.z;
+ s.w = s.w * td.w + kv.w;
+ }
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
+template <int block_size>
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = block_size;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
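+ // unlike wkv6, each thread holds row tid of the head's state matrix, so the a*state reduction below stays thread-local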
+#ifndef GGML_USE_MUSA
+ #pragma unroll
+#endif
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+ }
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _r[tid] = r[t];
+ _w[tid] = w[t];
+ _k[tid] = k[t];
+ _a[tid] = a[t];
+ _b[tid] = b[t];
+ __syncthreads();
+
+ float sa = 0;
+ #pragma unroll
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& a = (float4&)(_a[j]);
+ const float4& s = (float4&)(state[j]);
+ sa += a.x * s.x;
+ sa += a.y * s.y;
+ sa += a.z * s.z;
+ sa += a.w * s.w;
+ }
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& r = (float4&)(_r[j]);
+ const float4& w = (float4&)(_w[j]);
+ const float4& k = (float4&)(_k[j]);
+ const float4& b = (float4&)(_b[j]);
+ float4& s = (float4&)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ s.x = s.x * w.x + kv.x + sa * b.x;
+ s.y = s.y * w.y + kv.y + sa * b.y;
+ s.z = s.z * w.z + kv.z + sa * b.z;
+ s.w = s.w * w.w + kv.w + sa * b.w;
+
+ y += s.x * r.x;
+ y += s.y * r.y;
+ y += s.z * r.z;
+ y += s.w * r.w;
+ }
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+ }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * k_d = (const float *)dst->src[0]->data;
+ const float * v_d = (const float *)dst->src[1]->data;
+ const float * r_d = (const float *)dst->src[2]->data;
+ const float * tf_d = (const float *)dst->src[3]->data;
+ const float * td_d = (const float *)dst->src[4]->data;
+ const float * s_d = (const float *)dst->src[5]->data;
+
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+ if (C / H == CUDA_WKV_BLOCK_SIZE) {
+ rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+ } else {
+ rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+ }
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * r_d = (const float *)dst->src[0]->data;
+ const float * w_d = (const float *)dst->src[1]->data;
+ const float * k_d = (const float *)dst->src[2]->data;
+ const float * v_d = (const float *)dst->src[3]->data;
+ const float * a_d = (const float *)dst->src[4]->data;
+ const float * b_d = (const float *)dst->src[5]->data;
+ const float * s_d = (const float *)dst->src[6]->data;
+
+ const int64_t B = dst->src[6]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+ if (C / H == CUDA_WKV_BLOCK_SIZE) {
+ rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+ } else {
+ rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+ }
+}
--- /dev/null
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
float eps;
} ggml_metal_kargs_rms_norm;
+typedef struct {
+ int32_t ne00;
+ int32_t ne00_4;
+ uint64_t nb01;
+ float eps;
+} ggml_metal_kargs_l2_norm;
+
typedef struct {
int64_t ne00;
int64_t ne01;
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return true;
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
+ case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_OP_RWKV_WKV6:
+ {
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == 64);
+
+ size_t offs_src3 = 0;
+ size_t offs_src4 = 0;
+ size_t offs_src5 = 0;
+
+ id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+ id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+ id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
+ [encoder setBytes:&B length:sizeof(B) atIndex:7];
+ [encoder setBytes:&T length:sizeof(T) atIndex:8];
+ [encoder setBytes:&C length:sizeof(C) atIndex:9];
+ [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
+ } break;
+ case GGML_OP_RWKV_WKV7:
+ {
+ const int64_t B = dst->src[6]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == 64);
+
+ size_t offs_src3 = 0;
+ size_t offs_src4 = 0;
+ size_t offs_src5 = 0;
+ size_t offs_src6 = 0;
+
+ id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+ id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+ id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+ id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+
+ [encoder setBytes:&B length:sizeof(B) atIndex:8];
+ [encoder setBytes:&T length:sizeof(T) atIndex:9];
+ [encoder setBytes:&C length:sizeof(C) atIndex:10];
+ [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
+ } break;
case GGML_OP_MUL_MAT:
{
GGML_ASSERT(ne00 == ne10);
const int64_t nrows = ggml_nrows(src0);
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_L2_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+
+ nth = MIN(nth, ne00/4);
+
+ ggml_metal_kargs_l2_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne00_4 =*/ ne00/4,
+ /*.nb01 =*/ nb01,
+ /*.eps =*/ eps,
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_GROUP_NORM:
}
}
+kernel void kernel_rwkv_wkv6_f32(
+ device const float * k,
+ device const float * v,
+ device const float * r,
+ device const float * tf,
+ device const float * td,
+ device const float * state_in,
+ device float * dst,
+ constant uint & B,
+ constant uint & T,
+ constant uint & C,
+ constant uint & H,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const uint head_size = 64; // TODO: support head_size = 128
+ const uint batch_id = tgpig.x / H;
+ const uint head_id = tgpig.x % H;
+ const uint tid = tpitg.x;
+
+ if (batch_id >= B || head_id >= H) {
+ return;
+ }
+
+ const uint state_size = C * head_size;
+ const uint n_seq_tokens = T / B;
+
+ threadgroup float _k[head_size];
+ threadgroup float _r[head_size];
+ threadgroup float _tf[head_size];
+ threadgroup float _td[head_size];
+
+ float state[head_size];
+
+ for (uint i = 0; i < head_size; i++) {
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ + i * head_size + tid];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ _tf[tid] = tf[head_id * head_size + tid];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+ for (uint t = start_t; t < end_t; t += C) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ const float v_val = v[t];
+ float y = 0.0;
+
+ for (uint j = 0; j < head_size; j += 4) {
+ float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+ float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ float4 kv = k_vec * v_val;
+
+ float4 temp = tf_vec * kv + s_vec;
+ y += dot(r_vec, temp);
+
+ s_vec = s_vec * td_vec + kv;
+ state[j] = s_vec[0];
+ state[j+1] = s_vec[1];
+ state[j+2] = s_vec[2];
+ state[j+3] = s_vec[3];
+ }
+
+ dst[t] = y;
+ }
+
+ for (uint i = 0; i < head_size; i++) {
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ + i * head_size + tid] = state[i];
+ }
+}
+
+kernel void kernel_rwkv_wkv7_f32(
+ device const float * r,
+ device const float * w,
+ device const float * k,
+ device const float * v,
+ device const float * a,
+ device const float * b,
+ device const float * state_in,
+ device float * dst,
+ constant uint & B,
+ constant uint & T,
+ constant uint & C,
+ constant uint & H,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const uint head_size = 64; // TODO: support head_size = 128
+ const uint batch_id = tgpig.x / H;
+ const uint head_id = tgpig.x % H;
+ const uint tid = tpitg.x;
+
+ if (batch_id >= B || head_id >= H) {
+ return;
+ }
+
+ const uint state_size = C * head_size;
+ const uint n_seq_tokens = T / B;
+
+ threadgroup float _r[head_size];
+ threadgroup float _w[head_size];
+ threadgroup float _k[head_size];
+ threadgroup float _a[head_size];
+ threadgroup float _b[head_size];
+
+ float state[head_size];
+
+ for (uint i = 0; i < head_size; i++) {
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ + tid * head_size + i];
+ }
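+ // each thread keeps row tid of the head's state matrix, so the a*state sum below needs no threadgroup reduction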
+
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+ for (uint t = start_t; t < end_t; t += C) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ _r[tid] = r[t];
+ _w[tid] = w[t];
+ _k[tid] = k[t];
+ _a[tid] = a[t];
+ _b[tid] = b[t];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ const float v_val = v[t];
+ float y = 0.0, sa = 0.0;
+
+ float4 sa_vec(0.0);
+
+ for (uint j = 0; j < head_size; j += 4) {
+ float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+ sa_vec += a_vec * s_vec;
+ }
+ sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
+
+ for (uint j = 0; j < head_size; j += 4) {
+ float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+ float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ float4 kv = k_vec * v_val;
+
+ s_vec = s_vec * w_vec + kv + sa * b_vec;
+ y += dot(s_vec, r_vec);
+
+ state[j] = s_vec[0];
+ state[j+1] = s_vec[1];
+ state[j+2] = s_vec[2];
+ state[j+3] = s_vec[3];
+ }
+
+ dst[t] = y;
+ }
+
+ for (uint i = 0; i < head_size; i++) {
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ + tid * head_size + i] = state[i];
+ }
+}
+
kernel void kernel_argmax(
device const void * x,
device int32_t * dst,
}
}
+kernel void kernel_l2_norm(
+ constant ggml_metal_kargs_l2_norm & args,
+ device const char * src0,
+ device char * dst,
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+ float sumf = 0.0f;
+
+ // parallel sum
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ sumf += dot(x[i00], x[i00]);
+ }
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ const float scale = 1.0f/sqrt(max(sumf, args.eps * args.eps)); // clamp the norm (not the squared sum) by eps, matching the CPU/CUDA backends
+
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ y[i00] = x[i00] * scale;
+ }
+}
+
kernel void kernel_group_norm(
device const float * src0,
device float * dst,
#include "softmax.hpp"
#include "tsembd.hpp"
#include "im2col.hpp"
-#include "wkv6.hpp"
+#include "wkv.hpp"
#include "outprod.hpp"
#include "element_wise.hpp"
#include "cpy.hpp"
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
+static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
case GGML_OP_RMS_NORM:
ggml_sycl_rms_norm(ctx, dst);
break;
+ case GGML_OP_L2_NORM:
+ ggml_sycl_l2_norm(ctx, dst);
+ break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst);
break;
+ case GGML_OP_RWKV_WKV7:
+ ggml_sycl_op_rwkv_wkv7(ctx, dst);
+ break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
return (op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
default:
}
}
+static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
+ const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+ const int tid = item_ct1.get_local_id(2);
+ const int nthreads = item_ct1.get_local_range(2);
+ const int nwarps = nthreads / WARP_SIZE;
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row * ncols + col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ if (block_size > WARP_SIZE) {
+
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ /*
+ DPCT1118:3: SYCL group functions and algorithms must be encountered in
+ converged control flow. You may need to adjust the code.
+ */
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ size_t nreduce = nwarps / WARP_SIZE;
+ tmp = 0.f;
+ for (size_t i = 0; i < nreduce; i += 1)
+ {
+ tmp += s_sum[lane_id + i * WARP_SIZE];
+ }
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ }
+
+ const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row * ncols + col] = scale * x[row * ncols + col];
+ }
+}
+
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
}
}
+static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
+ const int nrows, const float eps,
+ queue_ptr stream, int device) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ if (ncols < 1024) {
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ stream->submit([&](sycl::handler& cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ l2_norm_f32(x, dst, ncols, eps, item_ct1,
+ nullptr, WARP_SIZE);
+ });
+ });
+ }
+ else {
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+ const sycl::range<3> block_dims(1, 1, work_group_size);
+ /*
+ DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+ cgh);
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ l2_norm_f32(x, dst, ncols, eps, item_ct1,
+ get_pointer(s_sum_acc_ct1), work_group_size);
+ });
+ });
+ }
+}
+
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd,
(void)dst;
(void)src1_dd;
}
+
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst,
+ const float* src0_dd, const float* src1_dd,
+ float* dst_dd,
+ const queue_ptr& main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+
+ (void)src1;
+ (void)dst;
+ (void)src1_dd;
+}
float* dst_dd,
const queue_ptr& main_stream);
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst,
+ const float* src0_dd, const float* src1_dd,
+ float* dst_dd,
+ const queue_ptr& main_stream);
+
#endif // GGML_SYCL_NORM_HPP
--- /dev/null
+#include <sycl/sycl.hpp>
+#include "wkv.hpp"
+
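+// SYCL ports of the CUDA wkv6/wkv7 kernels: one work-group per (batch, head) pair, one work-item per channel of the head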
+constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+template <int block_size>
+static void rwkv_wkv6_f32_kernel(
+ const int B, const int T, const int C, const int H,
+ const float* k, const float* v, const float* r,
+ const float* tf, const float* td, const float* s,
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+ const int tid = item_ct1.get_local_id(2);
+ const int bid = item_ct1.get_group(2);
+
+ const int head_size = block_size;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ // Set up shared memory pointers
+ float* _k = shared_mem;
+ float* _r = _k + head_size;
+ float* _tf = _r + head_size;
+ float* _td = _tf + head_size;
+
+ // Local state array
+ float state[block_size];
+
+ // Load initial state
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
+ // Sync threads before shared memory operations
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ // Load time-mixing parameters
+ _tf[tid] = tf[head_i * head_size + tid];
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ // Main sequence processing loop
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+ t += C) {
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ // Load current timestep data to shared memory
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ const float _v = v[t];
+ float y = 0;
+
+ // Process in chunks of 4 for better vectorization
+ sycl::float4 k4, r4, tf4, td4, s4;
+ #pragma unroll
+ for (int j = 0; j < head_size; j += 4) {
+ // Load data in vec4 chunks
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+ td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ // Compute key-value product
+ sycl::float4 kv4 = k4 * _v;
+
+ // Accumulate weighted sum
+ y += sycl::dot(r4, tf4 * kv4 + s4);
+
+ // Update state
+ s4 = s4 * td4 + kv4;
+
+ // Store updated state
+ state[j] = s4.x();
+ state[j+1] = s4.y();
+ state[j+2] = s4.z();
+ state[j+3] = s4.w();
+ }
+
+ dst[t] = y;
+ }
+
+ // Save final state
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
+template <int block_size>
+static void rwkv_wkv7_f32_kernel(
+ const int B, const int T, const int C, const int H,
+ const float* r, const float* w, const float* k, const float* v,
+ const float* a, const float* b, const float* s,
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+ const int tid = item_ct1.get_local_id(2);
+ const int bid = item_ct1.get_group(2);
+
+ const int head_size = block_size;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float* _r = shared_mem;
+ float* _w = _r + head_size;
+ float* _k = _w + head_size;
+ float* _a = _k + head_size;
+ float* _b = _a + head_size;
+
+ float state[block_size];
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+ }
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+ t += C) {
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ _r[tid] = r[t];
+ _w[tid] = w[t];
+ _k[tid] = k[t];
+ _a[tid] = a[t];
+ _b[tid] = b[t];
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ const float _v = v[t];
+ float y = 0, sa = 0;
+ sycl::float4 a4, s4;
+
+ #pragma unroll
+ for (int j = 0; j < head_size; j += 4) {
+ a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+ sa += sycl::dot(a4, s4);
+ }
+
+ sycl::float4 r4, w4, k4, b4;
+ #pragma unroll
+ for (int j = 0; j < head_size; j += 4) {
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ sycl::float4 kv4 = k4 * _v;
+
+ s4 = s4 * w4 + kv4 + sa * b4;
+ y += sycl::dot(r4, s4);
+
+ state[j] = s4.x();
+ state[j+1] = s4.y();
+ state[j+2] = s4.z();
+ state[j+3] = s4.w();
+ }
+
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+ }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+ const ggml_tensor *src0 = dst->src[0];
+ const ggml_tensor *src1 = dst->src[1];
+
+ const float* k_d = (const float*)dst->src[0]->data;
+ const float* v_d = (const float*)dst->src[1]->data;
+ const float* r_d = (const float*)dst->src[2]->data;
+ const float* tf_d = (const float*)dst->src[3]->data;
+ const float* td_d = (const float*)dst->src[4]->data;
+ const float* s_d = (const float*)dst->src[5]->data;
+ float* dst_d = (float*)dst->data;
+
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // the kernels are specialized for head sizes of 64 and 128
+
+ dpct::queue_ptr stream = ctx.stream();
+
+ // Calculate execution configuration
+ const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
+ sycl::range<3> block_dims(1, 1, C / H);
+ sycl::range<3> grid_dims(1, 1, B * H);
+
+ // Submit kernel
+ if (C / H == WKV_BLOCK_SIZE) {
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+ );
+ });
+ });
+ } else {
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+ );
+ });
+ });
+ }
+
+ GGML_UNUSED(src0);
+ GGML_UNUSED(src1);
+}
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+ const ggml_tensor *src0 = dst->src[0];
+ const ggml_tensor *src1 = dst->src[1];
+
+ const float* r_d = (const float*)dst->src[0]->data;
+ const float* w_d = (const float*)dst->src[1]->data;
+ const float* k_d = (const float*)dst->src[2]->data;
+ const float* v_d = (const float*)dst->src[3]->data;
+ const float* a_d = (const float*)dst->src[4]->data;
+ const float* b_d = (const float*)dst->src[5]->data;
+ const float* s_d = (const float*)dst->src[6]->data;
+ float* dst_d = (float*)dst->data;
+
+ const int64_t B = dst->src[6]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
+
+ dpct::queue_ptr stream = ctx.stream();
+
+ // Calculate execution configuration
+ const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
+ sycl::range<3> block_dims(1, 1, C / H);
+ sycl::range<3> grid_dims(1, 1, B * H);
+
+ // Submit kernel
+ if (C / H == WKV_BLOCK_SIZE) {
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
+ B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+ );
+ });
+ });
+ } else {
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
+ B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+ );
+ });
+ });
+ }
+
+ GGML_UNUSED(src0);
+ GGML_UNUSED(src1);
+}
--- /dev/null
+#ifndef GGML_SYCL_WKV_HPP
+#define GGML_SYCL_WKV_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_WKV_HPP
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_back_f32;
+ vk_pipeline pipeline_l2_norm_f32;
vk_pipeline pipeline_gelu_f32;
vk_pipeline pipeline_gelu_quick_f32;
vk_pipeline pipeline_silu_f32;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
+ vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
uint32_t H;
};
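+// same layout as vk_op_rwkv_wkv6_push_constants, so both ops can share the ggml_vk_op_f32_wkv dispatch path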
+struct vk_op_rwkv_wkv7_push_constants {
+ uint32_t B;
+ uint32_t T;
+ uint32_t C;
+ uint32_t H;
+};
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
for (auto &c : compiles) {
return ctx->device->pipeline_rms_norm_back_f32;
}
return nullptr;
+ case GGML_OP_L2_NORM:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_l2_norm_f32;
+ }
+ return nullptr;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_SILU:
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
+ case GGML_OP_RWKV_WKV7:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rwkv_wkv7_f32;
+ }
+ return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
}, dryrun);
}
-static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
- const ggml_tensor * k = dst->src[0];
- const ggml_tensor * v = dst->src[1];
- const ggml_tensor * r = dst->src[2];
- const ggml_tensor * tf = dst->src[3];
- const ggml_tensor * td = dst->src[4];
- const ggml_tensor * state = dst->src[5];
-
- GGML_ASSERT(!ggml_is_quantized(k->type));
- GGML_ASSERT(!ggml_is_quantized(v->type));
- GGML_ASSERT(!ggml_is_quantized(r->type));
- GGML_ASSERT(!ggml_is_quantized(tf->type));
- GGML_ASSERT(!ggml_is_quantized(td->type));
- GGML_ASSERT(!ggml_is_quantized(state->type));
+static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
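+ // shared dispatch for GGML_OP_RWKV_WKV6 (6 src tensors) and GGML_OP_RWKV_WKV7 (7 src tensors)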
+ GGML_ASSERT(version == 6 || version == 7);
+ int num_srcs = version == 6 ? 6 : 7;
+
+ for (int i = 0; i < num_srcs; i++) {
+ GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+ }
+
GGML_ASSERT(dst->buffer != nullptr);
- vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
}
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
- ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
- ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
- ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
- ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
- ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+ ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+ for (int i = 0; i < num_srcs; i++) {
+ src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+ }
ggml_vk_sync_buffers(subctx);
- vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
- size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
- bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+ vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+ size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+ bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
- ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
- ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
- ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
- ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
- ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
- ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+ for (int i = 0; i < num_srcs; i++) {
+ ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+ srcs_uma[i] = d_srcs[i] != nullptr;
+ }
- K_uma = d_K != nullptr;
- V_uma = d_V != nullptr;
- R_uma = d_R != nullptr;
- TF_uma = d_TF != nullptr;
- TD_uma = d_TD != nullptr;
- STATE_uma = d_State != nullptr;
- DST_uma = d_D != nullptr;
+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+ dst_uma = d_D != nullptr;
}
- if (!K_uma) {
- d_K = k_buf_ctx->dev_buffer;
- k_offset = vk_tensor_offset(k) + k->view_offs;
- }
- if (!V_uma) {
- d_V = v_buf_ctx->dev_buffer;
- v_offset = vk_tensor_offset(v) + v->view_offs;
- }
- if (!R_uma) {
- d_R = r_buf_ctx->dev_buffer;
- r_offset = vk_tensor_offset(r) + r->view_offs;
- }
- if (!TF_uma) {
- d_TF = tf_buf_ctx->dev_buffer;
- tf_offset = vk_tensor_offset(tf) + tf->view_offs;
- }
- if (!TD_uma) {
- d_TD = td_buf_ctx->dev_buffer;
- td_offset = vk_tensor_offset(td) + td->view_offs;
- }
- if (!STATE_uma) {
- d_State = state_buf_ctx->dev_buffer;
- state_offset = vk_tensor_offset(state) + state->view_offs;
+ uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+ for (int i = 0; i < num_srcs; i++) {
+ src_sizes[i] = ggml_nbytes(dst->src[i]);
+ if (!srcs_uma[i]) {
+ d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+ src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+ }
}
- if (!DST_uma) {
+
+ const uint64_t dst_size = ggml_nbytes(dst);
+ if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
- const uint64_t k_size = ggml_nbytes(k);
- const uint64_t v_size = ggml_nbytes(v);
- const uint64_t r_size = ggml_nbytes(r);
- const uint64_t tf_size = ggml_nbytes(tf);
- const uint64_t td_size = ggml_nbytes(td);
- const uint64_t state_size = ggml_nbytes(state);
- const uint64_t dst_size = ggml_nbytes(dst);
-
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
- vk_subbuffer{ d_K, k_offset, k_size },
- vk_subbuffer{ d_V, v_offset, v_size },
- vk_subbuffer{ d_R, r_offset, r_size },
- vk_subbuffer{ d_TF, tf_offset, tf_size },
- vk_subbuffer{ d_TD, td_offset, td_size },
- vk_subbuffer{ d_State, state_offset, state_size },
- vk_subbuffer{ d_D, dst_offset, dst_size }
- }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+ if (version == 6) {
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+ vk_subbuffer{ d_D, dst_offset, dst_size }
+ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+ } else if (version == 7) {
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+ vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+ vk_subbuffer{ d_D, dst_offset, dst_size }
+ }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+ } else {
+ // shouldn't happen
+ GGML_ASSERT(false);
+ }
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const size_t seq_length = dst->src[0]->ne[2];
    const size_t n_embed = dst->ne[0];
    const size_t n_heads = dst->src[0]->ne[1];
    const size_t n_seqs = dst->src[5]->ne[1];
- ggml_vk_op_f32_rwkv6(
+ ggml_vk_op_f32_wkv(
+ ctx, subctx, dst,
+ {
+ (uint32_t)n_seqs,
+ (uint32_t)seq_length,
+ (uint32_t)n_embed,
+ (uint32_t)n_heads,
+ },
+ 6,
+ dryrun
+ );
+}
+
+static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+ const size_t seq_length = dst->src[0]->ne[2];
+ const size_t n_embed = dst->ne[0];
+ const size_t n_heads = dst->src[0]->ne[1];
+ const size_t n_seqs = dst->src[6]->ne[1];
+
+ ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
            (uint32_t)n_seqs,
            (uint32_t)seq_length,
            (uint32_t)n_embed,
(uint32_t)n_heads,
},
+ 7,
dryrun
);
}
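Both wrappers fill the same four push constants that the wkv6/wkv7 compute shaders consume. The host-side definition of vk_op_rwkv_wkv7_push_constants is not part of this excerpt; judging from the shader's push-constant block further below, it presumably mirrors it field for field (a sketch, not the actual declaration from the patch):

struct vk_op_rwkv_wkv7_push_constants {
    uint32_t B;  // n_seqs
    uint32_t T;  // seq_length (total tokens)
    uint32_t C;  // n_embed
    uint32_t H;  // n_heads
};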
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
+static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
case GGML_OP_UNARY:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+ break;
+ case GGML_OP_L2_NORM:
+ ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
break;
+ case GGML_OP_RWKV_WKV7:
+ ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+ break;
+
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true;
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
} else if (tensor->op == GGML_OP_SILU_BACK) {
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+ } else if (tensor->op == GGML_OP_L2_NORM) {
+ const float eps = ((float *) tensor->op_params)[0];
+ tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+ } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+ tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+ src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
--- /dev/null
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+ sum[tid] += xi * xi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum[tid] += sum[tid + s];
+ }
+ barrier();
+ }
+
+    // scale the row by 1 / sqrt(max(sum of squares, eps)); eps arrives via p.param1
+    const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+ }
+}
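In math terms, the kernel above rescales each row x of length KX by the reciprocal square root of its clamped sum of squares, with eps carried in p.param1 (a restatement of the code, not a new formula):

    y_i = \frac{x_i}{\sqrt{\max\left(\sum_{j} x_j^2,\ \varepsilon\right)}}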
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+ string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
for (auto &c : compiles) {
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+ uint B;
+ uint T;
+ uint C;
+ uint H;
+};
+
+layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
+layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
+layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
+layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
+
+void main() {
+ const uint head_size = BLOCK_SIZE;
+ const uint batch_id = gl_WorkGroupID.x / H;
+ const uint head_id = gl_WorkGroupID.x % H;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint state_size = C * head_size;
+ const uint n_seq_tokens = T / B;
+
+ if (batch_id >= B || head_id >= H) {
+ return;
+ }
+
+    // each invocation keeps one row of this head's head_size x head_size state in registers
+    A_TYPE state[BLOCK_SIZE];
+ [[unroll]] for (uint i = 0; i < head_size; i++) {
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ + tid * head_size + i];
+ }
+
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+ for (uint t = start_t; t < end_t; t += C) {
+ barrier();
+        // stage this timestep's r/w/k/a/b for the whole head in shared memory
+        _r[tid] = r[t];
+ _w[tid] = w[t];
+ _k[tid] = k[t];
+ _a[tid] = a[t];
+ _b[tid] = b[t];
+ barrier();
+
+        // sa = dot(state row, a) for the current timestep
+        A_TYPE sa = 0.0;
+ [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+ vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+ sa += dot(s_vec, a_vec);
+ }
+
+ const A_TYPE v_val = v[t];
+ A_TYPE y = 0.0;
+
+        // state = state*w + k*v + sa*b, then accumulate the output y += dot(r, state)
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+ vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+ vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+ vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+ vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+ vec4 kv = k_vec * v_val;
+ s_vec = s_vec * w_vec + kv + sa * b_vec;
+ y += dot(r_vec, s_vec);
+
+ state[j] = s_vec.x;
+ state[j+1] = s_vec.y;
+ state[j+2] = s_vec.z;
+ state[j+3] = s_vec.w;
+ }
+
+ dst[t] = y;
+ }
+
+    // store the updated state after the T*C per-token outputs (the concat layout produced by ggml_rwkv_wkv7)
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ + tid * head_size + i] = state[i];
+ }
+}
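Per head, with S_t the head_size x head_size state matrix (each invocation above holds one row of it in registers), the time loop implements the recurrence below; this is a sketch in my own notation, read off the kernel rather than quoted from the patch:

    S_t = S_{t-1}\,\operatorname{diag}(w_t) + v_t\,k_t^{\top} + \left(S_{t-1}\,a_t\right) b_t^{\top}, \qquad y_t = S_t\, r_t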
"RMS_NORM",
"RMS_NORM_BACK",
"GROUP_NORM",
+ "L2_NORM",
"MUL_MAT",
"MUL_MAT_ID",
"ADD_REL_POS",
"RWKV_WKV6",
"GATED_LINEAR_ATTN",
+ "RWKV_WKV7",
"UNARY",
"OPT_STEP_ADAMW",
};
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"rms_norm(x)",
"rms_norm_back(x)",
"group_norm(x)",
+ "l2_norm(x)",
"X*Y",
"X[i]*Y",
"add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, gate, s)",
+ "rwkv_wkv7(r, w, k, v, a, b, s)",
"unary(x)",
"adamw(x)",
};
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
}
+// ggml_l2_norm
+
+static struct ggml_tensor * ggml_l2_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps,
+ bool inplace) {
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params_f32(result, 0, eps);
+
+ result->op = GGML_OP_L2_NORM;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_l2_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_l2_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_l2_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_l2_norm_impl(ctx, a, eps, true);
+}
+
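A minimal usage sketch of the new API; the context ctx, the tensor shape, and the eps value below are placeholders, not values taken from this patch:

// normalize each row of a [64, 32] f32 tensor to unit L2 norm
struct ggml_tensor * x        = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
struct ggml_tensor * x_normed = ggml_l2_norm(ctx, x, 1e-12f);  // same shape as x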
// ggml_mul_mat
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
return result;
}
+// ggml_rwkv_wkv7
+
+struct ggml_tensor * ggml_rwkv_wkv7(
+ struct ggml_context * ctx,
+ struct ggml_tensor * r,
+ struct ggml_tensor * w,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * state) {
+ GGML_ASSERT(ggml_is_contiguous(r));
+ GGML_ASSERT(ggml_is_contiguous(w));
+ GGML_ASSERT(ggml_is_contiguous(k));
+ GGML_ASSERT(ggml_is_contiguous(v));
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_is_contiguous(b));
+ GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];          // head size
+    const int64_t H = k->ne[1];          // number of heads
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+ {
+ GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
+ GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+ GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
+ GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+ }
+
+ // concat output and new_state
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_RWKV_WKV7;
+ result->src[0] = r;
+ result->src[1] = w;
+ result->src[2] = k;
+ result->src[3] = v;
+ result->src[4] = a;
+ result->src[5] = b;
+ result->src[6] = state;
+
+ return result;
+}
+
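A shape-level usage sketch of the new operator; ctx and all dimension values are hypothetical, chosen only to satisfy the asserts above:

const int64_t S = 64, H = 8, n_tokens = 16, n_seqs = 1;  // hypothetical sizes

struct ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

// the result packs the per-token output first, then the updated state:
// ne = { S * H, n_tokens + S * n_seqs, 1, 1 }
struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);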
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(