// TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
"WIN_UNPART",
"GET_REL_POS",
"ADD_REL_POS",
+ "RWKV_WKV",
"UNARY",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"win_unpart(x)",
"get_rel_pos(x)",
"add_rel_pos(x)",
+ "rwkv_wkv(k, v, r, tf, td, s)",
"unary(x)",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
"SILU",
"HARDSWISH",
"HARDSIGMOID",
+ "EXP",
};
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
}
+// ggml_exp
+
+struct ggml_tensor * ggml_exp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
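+
+// usage sketch (hypothetical tensor x, not part of this change):
+//   struct ggml_tensor * y  = ggml_exp(ctx, x);         // y[i] = exp(x[i]), result in a new tensor
+//   struct ggml_tensor * xe = ggml_exp_inplace(ctx, x); // same, but computed in place (xe is a view of x)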
+
// ggml_norm
static struct ggml_tensor * ggml_norm_impl(
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
}
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * r,
+ struct ggml_tensor * tf,
+ struct ggml_tensor * td,
+ struct ggml_tensor * state) {
+ GGML_ASSERT(ggml_is_contiguous(k));
+ GGML_ASSERT(ggml_is_contiguous(v));
+ GGML_ASSERT(ggml_is_contiguous(r));
+ GGML_ASSERT(ggml_is_contiguous(tf));
+ GGML_ASSERT(ggml_is_contiguous(td));
+ GGML_ASSERT(ggml_is_contiguous(state));
+
+ const int64_t S = k->ne[0];
+ const int64_t H = k->ne[2];
+ const int64_t n_tokens = k->ne[3];
+ const int64_t n_seqs = state->ne[1];
+ {
+ GGML_ASSERT(k->ne[1] == 1);
+ GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+ GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+ // TODO: RWKV v4 and v5
+ GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+ }
+
+ bool is_node = false;
+
+ if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+ GGML_ABORT("fatal error"); // TODO: implement backward
+ is_node = true;
+ }
+
+ // concat output and new_state
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_RWKV_WKV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = k;
+ result->src[1] = v;
+ result->src[2] = r;
+ result->src[3] = tf;
+ result->src[4] = td;
+ result->src[5] = state;
+
+ return result;
+}
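+
+// note: tf and td correspond to time_faaaa and time_decay in the CPU kernel below; the result packs
+// two logical tensors into one contiguous F32 buffer:
+//   - the first S*H*n_tokens floats are the wkv output (C = S*H values per token)
+//   - the next  S*S*H*n_seqs  floats are the updated state
+// a caller could split them with views, for example (sketch only, not part of this change;
+// S, H, n_tokens, n_seqs as defined above):
+//   struct ggml_tensor * out       = ggml_view_1d(ctx, result, S*H*n_tokens, 0);
+//   struct ggml_tensor * new_state = ggml_view_1d(ctx, result, S*S*H*n_seqs,
+//                                                 S*H*n_tokens*ggml_element_size(result));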
+
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(
}
}
+static void ggml_compute_forward_exp_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_exp_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_exp(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_exp_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_norm
{
ggml_compute_forward_hardsigmoid(params, dst);
} break;
+ case GGML_UNARY_OP_EXP:
+ {
+ ggml_compute_forward_exp(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
}
}
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ const size_t T = dst->src[1]->ne[3];
+ const size_t C = dst->ne[0];
+ const size_t H = dst->src[1]->ne[2];
+ const size_t n_seqs = dst->src[5]->ne[1];
+
+ float * dst_data = (float *) dst->data;
+ float * state = ((float *) dst->data) + C * T;
+
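+    // single-threaded for now: every thread other than 0 returns early (n_tasks == 1 for this op)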
+ if (params->ith != 0) {
+ return;
+ }
+
+ memset(dst_data, 0, T * C * sizeof(float));
+
+ float * k = (float *) dst->src[0]->data;
+ float * v = (float *) dst->src[1]->data;
+ float * r = (float *) dst->src[2]->data;
+ float * time_faaaa = (float *) dst->src[3]->data;
+ float * time_decay = (float *) dst->src[4]->data;
+
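+    // strides are in floats: a token spans t_stride = H*(C/H) = C values in k/v/r/td and dst;
+    // within a token each head spans h_stride = C/H values, and each per-head state matrix
+    // spans h_stride_2d = (C/H)^2 values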
+ size_t t_stride = H * (C / H);
+
+ size_t h_stride = C / H;
+ size_t h_stride_2d = (C / H) * (C / H);
+
+    // fused operation, applied recurrently over the tokens:
+    //   dst   = r @ (time_faaaa * (k @ v) + state)
+    //   state = time_decay * state + (k @ v)
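+    // per token t, head h and channel indices i (key/receptance) and j (value), the loops compute:
+    //   dst[t,h,j]  += r[t,h,i] * (time_faaaa[h,i] * k[t,h,i] * v[t,h,j] + state[h,i,j])   (summed over i)
+    //   state[h,i,j] = time_decay[t,h,i] * state[h,i,j] + k[t,h,i] * v[t,h,j]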
+ for (size_t t = 0; t < T; t++) {
+ size_t t_offset = t * t_stride;
+ size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+ for (size_t h = 0; h < H; h++) {
+ size_t h_offset = h * h_stride;
+ size_t t_h_offset = t_offset + h_offset;
+ size_t h_2d_offset = h * h_stride_2d;
+
+ for (size_t i = 0; i < C / H; i++) {
+ size_t t_h_i_offset = t_h_offset + i;
+ size_t h_i_offset = h_offset + i;
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+ float k_val = k[t_h_i_offset];
+ float r_val = r[t_h_i_offset];
+ float time_faaaa_val = time_faaaa[h_i_offset];
+ // RWKV v6: different time_decay for each token.
+ float time_decay_val = time_decay[t_h_i_offset];
+
+ for (size_t j = 0; j < C / H; j ++) {
+ size_t t_h_j_offset = t_h_offset + j;
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float v_val = v[t_h_j_offset];
+ float kv_val = v_val * k_val;
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ float temp_val = kv_val * time_faaaa_val + prev_state_val;
+ dst_data[t_h_j_offset] += temp_val * r_val;
+ state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rwkv_wkv_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32(
{
ggml_compute_forward_add_rel_pos(params, tensor);
} break;
+ case GGML_OP_RWKV_WKV:
+ {
+ ggml_compute_forward_rwkv_wkv(params, tensor);
+ } break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
zero_table);
}
} break;
+ case GGML_UNARY_OP_EXP:
+ {
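+                            // d/dx exp(x) = exp(x), so the forward output `tensor` doubles as the derivative factor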
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx, tensor, tensor->grad),
+ zero_table);
+ }
+ } break;
default:
GGML_ABORT("fatal error");
}
} break;
case GGML_OP_GET_REL_POS:
case GGML_OP_ADD_REL_POS:
+ case GGML_OP_RWKV_WKV:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_EXP:
{
n_tasks = 1;
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
+ case GGML_OP_RWKV_WKV:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32: